You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_auto_parallel_context.py 34 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context of auto parallel"""
import os
import threading

import mindspore.context as context
import mindspore.log as logger
from mindspore._c_expression import AutoParallelContext
from mindspore._checkparam import args_type_check, Validator
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
  23. _MAX_GROUP_NAME_LEN = 127
  24. _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
  25. _DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
  26. class _AutoParallelContext:
  27. """
  28. _AutoParallelContext is the environment in which operations are executed
  29. Note:
  30. Create a context through instantiating Context object is not recommended.
  31. Should use auto_parallel_context() to get the context since Context is singleton.
  32. """
  33. _instance = None
  34. _instance_lock = threading.Lock()
  35. def __init__(self):
  36. self._context_handle = AutoParallelContext.get_instance()
  37. self._dataset_strategy_using_str = True
  38. def __new__(cls):
  39. if cls._instance is None:
  40. cls._instance_lock.acquire()
  41. cls._instance = object.__new__(cls)
  42. cls._instance_lock.release()
  43. return cls._instance
  44. def check_context_handle(self):
  45. """
  46. Check context handle.
  47. Raises:
  48. ValueError: If the context handle is none.
  49. """
  50. if self._context_handle is None:
  51. raise ValueError("Context handle is none in context!!!")
  52. def set_device_num(self, device_num):
  53. """
  54. Set device num for auto parallel.
  55. Args:
  56. device_num (int): The device number.
  57. Raises:
  58. ValueError: If the device num is not in [1, 4096].
  59. """
  60. self.check_context_handle()
  61. if device_num < 1 or device_num > 4096:
  62. raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
  63. self._context_handle.set_device_num(device_num)
  64. def get_device_num(self):
  65. """Get device num."""
  66. self.check_context_handle()
  67. return self._context_handle.get_device_num()
  68. def set_global_rank(self, global_rank):
  69. """
  70. Set global rank for auto parallel.
  71. Args:
  72. global_rank (int): The rank id of current rank.
  73. Raises:
  74. ValueError: If the global rank is not in [1, 4096].
  75. """
  76. self.check_context_handle()
  77. if global_rank < 0 or global_rank > 4095:
  78. raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank))
  79. self._context_handle.set_global_rank(global_rank)
  80. def get_global_rank(self):
  81. """Get current rank id."""
  82. self.check_context_handle()
  83. return self._context_handle.get_global_rank()
  84. def set_pipeline_stages(self, stages):
  85. """Set the stages of the pipeline"""
  86. if isinstance(stages, bool):
  87. raise TypeError("The type of pipeline_stage_num must be int, but got bool.")
  88. if not isinstance(stages, int):
  89. raise TypeError("The type of pipeline_stage_num must be int.")
  90. if stages < 1:
  91. raise ValueError("pipeline_stage_num can't be less than 1.")
  92. backend = context.get_context("device_target")
  93. if backend == "GPU" and stages > 1:
  94. raise RuntimeError("Now GPU don't support pipeline parallel.")
  95. self.check_context_handle()
  96. self._context_handle.set_pipeline_stage_split_num(stages)
  97. def get_pipeline_stages(self):
  98. """Get the stages of the pipeline"""
  99. self.check_context_handle()
  100. return self._context_handle.get_pipeline_stage_split_num()
  101. def set_gradients_mean(self, gradients_mean):
  102. """
  103. Set gradients_mean flag.
  104. Note:
  105. If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.
  106. Args:
  107. gradients_mean (bool): The gradients_mean flag.
  108. """
  109. self.check_context_handle()
  110. self._context_handle.set_gradients_mean(gradients_mean)
  111. def get_gradients_mean(self):
  112. """Get gradients_mean flag."""
  113. self.check_context_handle()
  114. return self._context_handle.get_gradients_mean()
  115. def set_gradient_fp32_sync(self, gradient_fp32_sync):
  116. """
  117. Set gradient_fp32_sync.
  118. Note:
  119. If gradient_fp32_sync is true,
  120. it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
  121. Args:
  122. gradient_fp32_sync (bool): The gradient_fp32_sync flag.
  123. """
  124. self.check_context_handle()
  125. self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
  126. def get_gradient_fp32_sync(self):
  127. """Get gradient_fp32_sync flag."""
  128. self.check_context_handle()
  129. return self._context_handle.get_gradient_fp32_sync()
  130. def set_loss_repeated_mean(self, loss_repeated_mean):
  131. """
  132. Set loss_repeated_mean flag.
  133. Note:
  134. If loss_repeated_mean is true,
  135. Distributed automatic differentiation will perform a mean operator
  136. in backward in the case of repeated calculations.
  137. Args:
  138. loss_repeated_mean (bool): The loss_repeated_mean flag.
  139. """
  140. if not isinstance(loss_repeated_mean, bool):
  141. raise TypeError(f"The type of loss_repeated_mean must be bool, but got {type(loss_repeated_mean)}.")
  142. self.check_context_handle()
  143. self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
  144. def get_loss_repeated_mean(self):
  145. """Get loss_repeated_mean flag."""
  146. self.check_context_handle()
  147. return self._context_handle.get_loss_repeated_mean()
  148. def set_parallel_mode(self, parallel_mode):
  149. """
  150. Set parallel mode for auto parallel.
  151. Args:
  152. parallel_mode (str): The parallel mode of auto parallel.
  153. Raises:
  154. ValueError: If parallel mode is not supported.
  155. """
  156. self.check_context_handle()
  157. ret = self._context_handle.set_parallel_mode(parallel_mode)
  158. if ret is False:
  159. raise ValueError("Parallel mode does not support {}".format(parallel_mode))
  160. def get_parallel_mode(self):
  161. """Get parallel mode."""
  162. self.check_context_handle()
  163. if _is_role_pserver():
  164. return context.ParallelMode.STAND_ALONE
  165. return self._context_handle.get_parallel_mode()
  166. def set_strategy_search_mode(self, auto_parallel_search_mode):
  167. """
  168. Set search mode of strategy.
  169. Args:
  170. auto_parallel_search_mode (str): The search mode of strategy.
  171. """
  172. self.check_context_handle()
  173. ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
  174. if ret is False:
  175. raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode))
  176. def get_strategy_search_mode(self):
  177. """Get search mode of strategy."""
  178. self.check_context_handle()
  179. return self._context_handle.get_strategy_search_mode()
  180. def set_parameter_broadcast(self, parameter_broadcast):
  181. """
  182. Set parameter broadcast.
  183. Args:
  184. parameter_broadcast (bool): Parameter broadcast or not.
  185. """
  186. self.check_context_handle()
  187. self._context_handle.set_parameter_broadcast(parameter_broadcast)
  188. def get_parameter_broadcast(self):
  189. """Get parameter broadcast flag."""
  190. self.check_context_handle()
  191. return self._context_handle.get_parameter_broadcast()
  192. def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
  193. """
  194. Set strategy checkpoint load path.
  195. Args:
  196. strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
  197. """
  198. self.check_context_handle()
  199. self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
  200. def get_strategy_ckpt_load_file(self):
  201. """Get strategy checkpoint load path."""
  202. self.check_context_handle()
  203. return self._context_handle.get_strategy_ckpt_load_file()
  204. def set_full_batch(self, full_batch):
  205. """
  206. Set whether load full batch on each device.
  207. Args:
  208. full_batch (bool): True if load full batch on each device.
  209. """
  210. self.check_context_handle()
  211. self._context_handle.set_full_batch(full_batch)
  212. def get_full_batch(self):
  213. """Get whether load full batch on each device."""
  214. self.check_context_handle()
  215. if _is_role_pserver():
  216. return False
  217. return self._context_handle.get_full_batch()
  218. def set_dataset_strategy(self, dataset_strategy):
  219. """
  220. Set dataset sharding strategy.
  221. Args:
  222. dataset_strategy (str or tuple(tuple)): The dataset sharding strategy.
  223. """
  224. self.check_context_handle()
  225. if isinstance(dataset_strategy, str):
  226. if dataset_strategy not in ("full_batch", "data_parallel"):
  227. raise ValueError("The dataset_strategy string should be 'full_batch' or 'data_parallel', "
  228. "otherwise, incoming tuple(tuple) type strategy")
  229. self._context_handle.set_full_batch(dataset_strategy == "full_batch")
  230. self._dataset_strategy_using_str = True
  231. return
  232. if not isinstance(dataset_strategy, tuple):
  233. raise TypeError(f'strategy must be str or tuple type, but got:{type(dataset_strategy)}')
  234. for ele in dataset_strategy:
  235. if not isinstance(ele, tuple):
  236. raise TypeError(f'The element of strategy must be tuple type, but got:{type(ele)}')
  237. for dim in ele:
  238. if not isinstance(dim, int):
  239. raise TypeError(f'The dim of each strategy value must be int type, but got:{type(dim)}')
  240. self._dataset_strategy_using_str = False
  241. self._context_handle.set_dataset_strategy(dataset_strategy)
  242. def get_dataset_strategy(self):
  243. """Get dataset sharding strategy."""
  244. self.check_context_handle()
  245. if self._dataset_strategy_using_str:
  246. if self._context_handle.get_full_batch():
  247. return "full_batch"
  248. return "data_parallel"
  249. return self._context_handle.get_dataset_strategy()
  250. def set_grad_accumulation_step(self, grad_accumulation_step):
  251. """
  252. Set grad accumulation step.
  253. Args:
  254. grad_accumulation_step (int): The grad accumulation step.
  255. """
  256. self.check_context_handle()
  257. Validator.check_positive_int(grad_accumulation_step)
  258. self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
  259. def get_grad_accumulation_step(self):
  260. """Get grad accumulation step."""
  261. self.check_context_handle()
  262. return self._context_handle.get_grad_accumulation_step()
  263. def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
  264. """
  265. Set strategy checkpoint save path.
  266. Args:
  267. strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
  268. """
  269. self.check_context_handle()
  270. import os
  271. dir_path = os.path.dirname(strategy_ckpt_save_file)
  272. if dir_path and not os.path.exists(dir_path):
  273. os.makedirs(dir_path)
  274. self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
  275. def get_strategy_ckpt_save_file(self):
  276. """Get strategy checkpoint save path."""
  277. self.check_context_handle()
  278. return self._context_handle.get_strategy_ckpt_save_file()
  279. def set_group_ckpt_save_file(self, group_ckpt_save_file):
  280. """Set group checkpoint save path."""
  281. self.check_context_handle()
  282. import os
  283. dir_path = os.path.dirname(group_ckpt_save_file)
  284. if dir_path and not os.path.exists(dir_path):
  285. os.makedirs(dir_path)
  286. self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)
  287. def get_parameter_broadcast_is_set(self):
  288. """Get parameter broadcast is set or not."""
  289. self.check_context_handle()
  290. return self._context_handle.get_parameter_broadcast_is_set()
  291. def set_all_reduce_fusion_split_indices(self, indices, group=""):
  292. """
  293. Set allreduce fusion strategy by parameters indices.
  294. Args:
  295. indices (list): Indices list.
  296. group (str): The communication group of hccl/nccl.
  297. Raises:
  298. TypeError: If type of indices item is not int.
  299. TypeError: If group is not a python str.
  300. """
  301. self.check_context_handle()
  302. if not indices:
  303. raise ValueError('indices can not be empty')
  304. if isinstance(indices, (list)):
  305. for index in indices:
  306. if not isinstance(index, int) or isinstance(index, bool):
  307. raise TypeError(f"The type of index must be int, but got {type(index)}.")
  308. else:
  309. raise TypeError('indices must be a python list')
  310. if len(set(indices)) != len(indices):
  311. raise ValueError('indices has duplicate elements')
  312. if sorted(indices) != indices:
  313. raise ValueError('elements in indices must be sorted in ascending order')
  314. new_group = self._check_and_default_group(group)
  315. self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
  316. if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
  317. _set_fusion_strategy_by_idx(indices)
  318. def get_all_reduce_fusion_split_indices(self, group=""):
  319. """
  320. Get allreduce fusion split indices.
  321. Args:
  322. group (str): The communication group of hccl/nccl.
  323. Returns:
  324. Return split sizes list according to the group.
  325. Raises:
  326. TypeError: If group is not a python str.
  327. """
  328. self.check_context_handle()
  329. new_group = self._check_and_default_group(group)
  330. return self._context_handle.get_all_reduce_fusion_split_indices(new_group)
  331. def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
  332. """
  333. Set allreduce fusion strategy by parameters data sizes.
  334. Args:
  335. sizes (list): Sizes list.
  336. group (str): The communication group of hccl/nccl.
  337. Raises:
  338. TypeError: If type of sizes item is not int.
  339. TypeError: If group is not a python str.
  340. """
  341. self.check_context_handle()
  342. if isinstance(sizes, (list)):
  343. for size in sizes:
  344. if not isinstance(size, int) or isinstance(size, bool):
  345. raise TypeError(f"The type of size must be int, but got {type(size)}.")
  346. else:
  347. raise TypeError('sizes must be a python list')
  348. new_group = self._check_and_default_group(group)
  349. self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
  350. if context.get_context("device_target") == "Ascend":
  351. _set_fusion_strategy_by_size(sizes)
  352. def get_all_reduce_fusion_split_sizes(self, group=""):
  353. """
  354. Get allreduce fusion split sizes.
  355. Args:
  356. group (str): The communication group of hccl/nccl.
  357. Returns:
  358. Return split sizes list according to the group.
  359. Raises:
  360. TypeError: If group is not a python str.
  361. """
  362. self.check_context_handle()
  363. new_group = self._check_and_default_group(group)
  364. return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)
  365. def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
  366. """
  367. Set enable/disable all reduce fusion.
  368. Args:
  369. enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
  370. """
  371. self.check_context_handle()
  372. if not isinstance(enable_all_reduce_fusion, bool):
  373. raise TypeError('enable_all_reduce_fusion is invalid type')
  374. self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
  375. def get_enable_all_reduce_fusion(self):
  376. """Get all reduce fusion flag."""
  377. self.check_context_handle()
  378. return self._context_handle.get_enable_all_reduce_fusion()
  379. def get_device_num_is_set(self):
  380. """Get device number is set or not."""
  381. self.check_context_handle()
  382. return self._context_handle.get_device_num_is_set()
  383. def get_global_rank_is_set(self):
  384. """Get global rank is set or not."""
  385. self.check_context_handle()
  386. return self._context_handle.get_global_rank_is_set()
  387. def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
  388. """
  389. Set enable/disable parallel optimizer.
  390. Args:
  391. set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
  392. """
  393. self.check_context_handle()
  394. if not isinstance(enable_parallel_optimizer, bool):
  395. raise TypeError('enable_parallel_optimizer is invalid type')
  396. self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
  397. def get_enable_parallel_optimizer(self):
  398. """Get parallel optimizer flag."""
  399. self.check_context_handle()
  400. return self._context_handle.get_enable_parallel_optimizer()
  401. def set_sharding_propagation(self, sharding_propagation):
  402. """
  403. Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
  404. will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
  405. will search the desired strategies.
  406. Default: False.
  407. Args:
  408. sharding_propagation (bool): Enable/disable strategy propagation.
  409. """
  410. self.check_context_handle()
  411. if not isinstance(sharding_propagation, bool):
  412. raise TypeError("'sharding_propagation' is an invalid type.")
  413. self._context_handle.set_sharding_propagation(sharding_propagation)
  414. def get_sharding_propagation(self):
  415. """Get the value of sharding strategy propagation."""
  416. self.check_context_handle()
  417. return self._context_handle.get_sharding_propagation()
  418. def set_enable_alltoall(self, enable_a2a):
  419. """
  420. Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
  421. Default: False.
  422. Args:
  423. enable_a2a (bool): Enable/disable AllToAll.
  424. """
  425. self.check_context_handle()
  426. if not isinstance(enable_a2a, bool):
  427. raise TypeError("'enable_a2a' is an invalid type.")
  428. self._context_handle.set_enable_alltoall(enable_a2a)
  429. def get_enable_alltoall(self):
  430. """Get the value of enabling AllToAll."""
  431. self.check_context_handle()
  432. return self._context_handle.get_enable_alltoall()
  433. def set_communi_parallel_mode(self, communi_parallel_mode):
  434. """
  435. Set communication parallel mode.
  436. Args:
  437. communi_parallel_mode (str): The communication parallel mode.
  438. Raises:
  439. ValueError: If parallel mode is not supported.
  440. """
  441. if not isinstance(communi_parallel_mode, str):
  442. raise TypeError(f"The type of communi_parallel_mode must be str, \
  443. but got {type(communi_parallel_mode)}.")
  444. self.check_context_handle()
  445. ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
  446. if ret is False:
  447. raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode))
  448. def get_communi_parallel_mode(self):
  449. """Get communication parallel mode."""
  450. self.check_context_handle()
  451. return self._context_handle.get_communi_parallel_mode()
  452. def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
  453. """
  454. Set optimizer_weight_shard_size.
  455. Args:
  456. optimizer_weight_shard_size (int): Opt shard group size when not globally use parallel
  457. optimizer across devices.
  458. """
  459. self.check_context_handle()
  460. if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool):
  461. raise TypeError(f"The type of optimizer_weight_shard_size must be int, \
  462. but got {type(optimizer_weight_shard_size)}.")
  463. if optimizer_weight_shard_size <= 1:
  464. logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
  465. "Please use the integer larger than 1.")
  466. return
  467. self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)
  468. def get_optimizer_weight_shard_size(self):
  469. """Get optimizer_weight_shard_size."""
  470. self.check_context_handle()
  471. return self._context_handle.get_optimizer_weight_shard_size()
  472. def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
  473. """
  474. Set optimizer_weight_shard_aggregated_save.
  475. Args:
  476. optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when
  477. enable parallel optimizer.
  478. """
  479. self.check_context_handle()
  480. if not isinstance(optimizer_weight_shard_aggregated_save, bool):
  481. raise TypeError('optimizer_weight_shard_aggregated_save is invalid type')
  482. self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save)
  483. def get_optimizer_weight_shard_aggregated_save(self):
  484. """Get optimizer_weight_shard_size."""
  485. self.check_context_handle()
  486. return self._context_handle.get_optimizer_weight_shard_aggregated_save()
  487. def reset(self):
  488. """Reset all settings."""
  489. self.check_context_handle()
  490. self._context_handle.reset()
  491. def _check_and_default_group(self, group):
  492. """Validate the given group, if group is empty, returns a default fusion group"""
  493. if isinstance(group, (str)):
  494. group_len = len(group)
  495. if group_len > _MAX_GROUP_NAME_LEN:
  496. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  497. else:
  498. raise TypeError('Group must be a python str')
  499. if group == "":
  500. if context.get_context("device_target") == "Ascend":
  501. group = _DEFAULT_HCCL_FUSION_GROUP_NAME
  502. else:
  503. group = _DEFAULT_NCCL_FUSION_GROUP_NAME
  504. return group
  505. _auto_parallel_context = None
  506. def auto_parallel_context():
  507. """
  508. Get the global _auto_parallel_context, if it is not created, create a new one.
  509. Returns:
  510. _AutoParallelContext, the global auto parallel context.
  511. """
  512. global _auto_parallel_context
  513. if _auto_parallel_context is None:
  514. _auto_parallel_context = _AutoParallelContext()
  515. return _auto_parallel_context
  516. _set_auto_parallel_context_func_map = {
  517. "device_num": auto_parallel_context().set_device_num,
  518. "global_rank": auto_parallel_context().set_global_rank,
  519. "gradients_mean": auto_parallel_context().set_gradients_mean,
  520. "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
  521. "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
  522. "pipeline_stages": auto_parallel_context().set_pipeline_stages,
  523. "parallel_mode": auto_parallel_context().set_parallel_mode,
  524. "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
  525. "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
  526. "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
  527. "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
  528. "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
  529. "full_batch": auto_parallel_context().set_full_batch,
  530. "dataset_strategy": auto_parallel_context().set_dataset_strategy,
  531. "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
  532. "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
  533. "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
  534. "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
  535. "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
  536. "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
  537. "sharding_propagation": auto_parallel_context().set_sharding_propagation,
  538. "enable_alltoall": auto_parallel_context().set_enable_alltoall}
  539. _get_auto_parallel_context_func_map = {
  540. "device_num": auto_parallel_context().get_device_num,
  541. "global_rank": auto_parallel_context().get_global_rank,
  542. "gradients_mean": auto_parallel_context().get_gradients_mean,
  543. "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
  544. "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
  545. "pipeline_stages": auto_parallel_context().get_pipeline_stages,
  546. "parallel_mode": auto_parallel_context().get_parallel_mode,
  547. "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
  548. "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
  549. "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
  550. "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
  551. "full_batch": auto_parallel_context().get_full_batch,
  552. "dataset_strategy": auto_parallel_context().get_dataset_strategy,
  553. "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
  554. "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
  555. "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
  556. "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
  557. "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
  558. "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
  559. "sharding_propagation": auto_parallel_context().get_sharding_propagation,
  560. "enable_alltoall": auto_parallel_context().get_enable_alltoall}
  561. @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
  562. loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
  563. parameter_broadcast=bool, strategy_ckpt_load_file=str,
  564. strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
  565. grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
  566. communi_parallel_mode=str, optimizer_weight_shard_size=int,
  567. optimizer_weight_shard_aggregated_save=bool,
  568. sharding_propagation=bool, enable_alltoall=bool)
  569. def _set_auto_parallel_context(**kwargs):
  570. """
  571. Set auto parallel context.
  572. Note:
  573. Attribute name is required for setting attributes.
  574. Args:
  575. device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
  576. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
  577. gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
  578. loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
  579. calculations. Default: True.
  580. gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
  581. Default: True.
  582. parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
  583. "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
  584. - stand_alone: Only one processor working.
  585. - data_parallel: Distributing the data across different processors.
  586. - hybrid_parallel: Achieving data parallelism and model parallelism manually.
  587. - semi_auto_parallel: Achieving data parallelism and model parallelism by
  588. setting parallel strategies.
  589. - auto_parallel: Achieving parallelism automatically.
  590. auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
  591. and "dynamic_programming". Default: "dynamic_programming".
  592. - recursive_programming: Recursive programming search mode.
  593. - dynamic_programming: Dynamic programming search mode.
  594. parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
  595. "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
  596. broadcast. Default: False.
  597. strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
  598. strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
  599. group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
  600. full_batch (bool): Whether to load the whole batch on each device. Default: False.
  601. dataset_strategy Union[str, tuple]: Dataset sharding strategy. Default: "data_parallel".
  602. enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
  603. all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
  604. pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
  605. the devices are distributed alone the pipeline. The total devices will be divided into
  606. 'pipeline_stags' stages. This currently could only be used when
  607. parallel mode semi_auto_parallel is enabled. Default: 0
  608. communi_parallel_mode (str): There are tree kinds of communication parallel modes, "all_group_parallel",
  609. "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
  610. - all_group_parallel: All communication groups are in parallel.
  611. - same_server_group_parallel: Only the communication groups within the same server are parallel.
  612. - no_group_parallel: All communication groups are not parallel.
  613. optimizer_weight_shard_size (int): Set optimizer shard group size when not fully use parallel optimizer.
  614. It should be larger than one and less than or equal with the data parallel size.
  615. Default: -1, which means fully use parallel optimizer in data parallel dimension.
  616. optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when enable parallel
  617. optimizer. Default: False.
  618. sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True,
  619. the strategy-configured operators will propagate the strategies to other
  620. operators with minimum redistribution cost; otherwise, the algorithm will
  621. search the desired strategies. Default: False.
  622. enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
  623. circumvent AllToAll. Default: False.
  624. Raises:
  625. ValueError: If input key is not attribute in auto parallel context.
  626. """
  627. for key, value in kwargs.items():
  628. if key not in _set_auto_parallel_context_func_map:
  629. raise ValueError("Set context keyword %s is not recognized!" % key)
  630. set_func = _set_auto_parallel_context_func_map[key]
  631. set_func(value)
  632. def _get_auto_parallel_context(attr_key):
  633. """
  634. Get auto parallel context attribute value according to the key.
  635. Args:
  636. attr_key (str): The key of the attribute.
  637. Returns:
  638. Return attribute value according to the key.
  639. Raises:
  640. ValueError: If input key is not attribute in auto parallel context.
  641. """
  642. if attr_key not in _get_auto_parallel_context_func_map:
  643. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  644. get_func = _get_auto_parallel_context_func_map[attr_key]
  645. return get_func()
  646. def _reset_auto_parallel_context():
  647. """
  648. Reset auto parallel context attributes to the default values:
  649. - device_num: 1.
  650. - global_rank: 0.
  651. - gradients_mean: False.
  652. - gradient_fp32_sync: True.
  653. - parallel_mode: "stand_alone".
  654. - parameter_broadcast: False.
  655. - strategy_ckpt_load_file: ""
  656. - strategy_ckpt_save_file: ""
  657. - enable_parallel_optimizer: False
  658. - auto_parallel_search_mode: dynamic_programming
  659. - pipeline_stages: 0
  660. """
  661. auto_parallel_context().reset()