You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_auto_parallel_context.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context of auto parallel"""
  16. import threading
  17. import mindspore.context as context
  18. from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
  19. from mindspore._c_expression import AutoParallelContext
  20. from mindspore._checkparam import args_type_check
# Upper bound on the length of a communication group name accepted by the
# group-name validation checks in this file.
_MAX_GROUP_NAME_LEN = 127
  22. class _AutoParallelContext:
  23. """
  24. _AutoParallelContext is the environment in which operations are executed
  25. Note:
  26. Create a context through instantiating Context object is not recommended.
  27. Should use auto_parallel_context() to get the context since Context is singleton.
  28. """
  29. _instance = None
  30. _instance_lock = threading.Lock()
  31. def __init__(self):
  32. self._context_handle = AutoParallelContext.get_instance()
  33. def __new__(cls):
  34. if cls._instance is None:
  35. cls._instance_lock.acquire()
  36. cls._instance = object.__new__(cls)
  37. cls._instance_lock.release()
  38. return cls._instance
  39. def check_context_handle(self):
  40. """
  41. Check context handle.
  42. Raises:
  43. ValueError: If the context handle is none.
  44. """
  45. if self._context_handle is None:
  46. raise ValueError("Context handle is none in context!!!")
  47. def set_device_num(self, device_num):
  48. """
  49. Set device num for auto parallel.
  50. Args:
  51. device_num (int): The device number.
  52. Raises:
  53. ValueError: If the device num is not in [1, 4096].
  54. """
  55. self.check_context_handle()
  56. if device_num < 1 or device_num > 4096:
  57. raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
  58. self._context_handle.set_device_num(device_num)
  59. def get_device_num(self):
  60. """Get device num."""
  61. self.check_context_handle()
  62. return self._context_handle.get_device_num()
  63. def set_global_rank(self, global_rank):
  64. """
  65. Set global rank for auto parallel.
  66. Args:
  67. global_rank (int): The rank id of current rank.
  68. Raises:
  69. ValueError: If the global rank is not in [1, 4096].
  70. """
  71. self.check_context_handle()
  72. if global_rank < 0 or global_rank > 4095:
  73. raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank))
  74. self._context_handle.set_global_rank(global_rank)
  75. def get_global_rank(self):
  76. """Get current rank id."""
  77. self.check_context_handle()
  78. return self._context_handle.get_global_rank()
  79. def set_mirror_mean(self, mirror_mean):
  80. """
  81. Set mirror_mean flag.
  82. Note:
  83. If mirror_mean is true, it will insert a div operator after parameter gradients allreduce.
  84. Args:
  85. mirror_mean (bool): The mirror_mean flag.
  86. """
  87. self.check_context_handle()
  88. self._context_handle.set_mirror_mean(mirror_mean)
  89. def get_mirror_mean(self):
  90. """Get mirror_mean flag."""
  91. self.check_context_handle()
  92. return self._context_handle.get_mirror_mean()
  93. def set_cast_before_mirror(self, cast_before_mirror):
  94. """
  95. Set cast_before_mirror.
  96. Note:
  97. If cast_before_mirror is true,
  98. it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
  99. Args:
  100. cast_before_mirror (bool): The cast_before_mirror flag.
  101. """
  102. self.check_context_handle()
  103. self._context_handle.set_cast_before_mirror(cast_before_mirror)
  104. def get_cast_before_mirror(self):
  105. """Get cast_before_mirror flag."""
  106. self.check_context_handle()
  107. return self._context_handle.get_cast_before_mirror()
  108. def set_loss_repeated_mean(self, loss_repeated_mean):
  109. """
  110. Set loss_repeated_mean flag.
  111. Note:
  112. If loss_repeated_mean is true,
  113. Distributed automatic differentiation will perform a mean operator
  114. in backward in the case of repeated calculations.
  115. Args:
  116. loss_repeated_mean (bool): The loss_repeated_mean flag.
  117. """
  118. self.check_context_handle()
  119. self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
  120. def get_loss_repeated_mean(self):
  121. """Get loss_repeated_mean flag."""
  122. self.check_context_handle()
  123. return self._context_handle.get_loss_repeated_mean()
  124. def set_communication_backend(self, communication_backend):
  125. """
  126. Set communication backend.
  127. Args:
  128. communication_backend (str): The communication backend.
  129. """
  130. self.check_context_handle()
  131. self._context_handle.set_communication_backend(communication_backend)
  132. def get_communication_backend(self):
  133. """Get communication backend."""
  134. self.check_context_handle()
  135. return self._context_handle.get_communication_backend()
  136. def set_parallel_mode(self, parallel_mode):
  137. """
  138. Set parallel mode for auto parallel.
  139. Args:
  140. parallel_mode (str): The parallel mode of auto parallel.
  141. Raises:
  142. ValueError: If parallel mode is not supported.
  143. """
  144. self.check_context_handle()
  145. ret = self._context_handle.set_parallel_mode(parallel_mode)
  146. if ret is False:
  147. raise ValueError("Parallel mode does not support {}".format(parallel_mode))
  148. def get_parallel_mode(self):
  149. """Get parallel mode."""
  150. self.check_context_handle()
  151. return self._context_handle.get_parallel_mode()
  152. def set_strategy_search_mode(self, auto_parallel_search_mode):
  153. """
  154. Set search mode of strategy.
  155. Args:
  156. auto_parallel_search_mode (str): The search mode of strategy.
  157. """
  158. self.check_context_handle()
  159. ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
  160. if ret is False:
  161. raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode))
  162. def get_strategy_search_mode(self):
  163. """Get search mode of strategy."""
  164. self.check_context_handle()
  165. return self._context_handle.get_strategy_search_mode()
  166. def set_parameter_broadcast(self, parameter_broadcast):
  167. """
  168. Set parameter broadcast.
  169. Args:
  170. parameter_broadcast (bool): Parameter broadcast or not.
  171. """
  172. self.check_context_handle()
  173. self._context_handle.set_parameter_broadcast(parameter_broadcast)
  174. def get_parameter_broadcast(self):
  175. """Get parameter broadcast flag."""
  176. self.check_context_handle()
  177. return self._context_handle.get_parameter_broadcast()
  178. def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
  179. """
  180. Set strategy checkpoint load path.
  181. Args:
  182. strategy_ckpt_load_file (bool): Path to load parallel strategy checkpoint.
  183. """
  184. self.check_context_handle()
  185. self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
  186. def get_strategy_ckpt_load_file(self):
  187. """Get strategy checkpoint load path."""
  188. self.check_context_handle()
  189. return self._context_handle.get_strategy_ckpt_load_file()
  190. def set_full_batch(self, full_batch):
  191. """
  192. Set whether load full batch on each device.
  193. Args:
  194. full_batch (bool): True if load full batch on each device.
  195. """
  196. self.check_context_handle()
  197. self._context_handle.set_full_batch(full_batch)
  198. def get_full_batch(self):
  199. """Get whether load full batch on each device."""
  200. self.check_context_handle()
  201. return self._context_handle.get_full_batch()
  202. def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
  203. """
  204. Set strategy checkpoint save path.
  205. Args:
  206. strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
  207. """
  208. self.check_context_handle()
  209. self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
  210. def get_strategy_ckpt_save_file(self):
  211. """Get strategy checkpoint save path."""
  212. self.check_context_handle()
  213. return self._context_handle.get_strategy_ckpt_save_file()
  214. def get_parameter_broadcast_is_set(self):
  215. """Get parameter broadcast is set or not."""
  216. self.check_context_handle()
  217. return self._context_handle.get_parameter_broadcast_is_set()
  218. def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"):
  219. """
  220. Set allreduce fusion strategy by parameters indices.
  221. Args:
  222. indices (list): Indices list.
  223. group (str): The hccl communication group.
  224. Raises:
  225. TypeError: If type of indices item is not int.
  226. TypeError: If group is not a python str.
  227. """
  228. self.check_context_handle()
  229. if isinstance(indices, (list)):
  230. for index in indices:
  231. if not isinstance(index, int):
  232. raise TypeError('indices has invalid value')
  233. else:
  234. raise TypeError('indices must be a python list')
  235. if isinstance(group, (str)):
  236. group_len = len(group)
  237. if group_len > _MAX_GROUP_NAME_LEN:
  238. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  239. else:
  240. raise TypeError('Group must be a python str')
  241. self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
  242. if context.get_context("device_target") == "Ascend":
  243. _set_fusion_strategy_by_idx(indices)
  244. def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"):
  245. """
  246. Get allreduce fusion split indices.
  247. Args:
  248. group (str): The hccl communication group.
  249. Returns:
  250. Return split sizes list according to the group.
  251. Raises:
  252. TypeError: If group is not a python str.
  253. """
  254. self.check_context_handle()
  255. if isinstance(group, (str)):
  256. group_len = len(group)
  257. if group_len > _MAX_GROUP_NAME_LEN:
  258. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  259. else:
  260. raise TypeError('Group must be a python str')
  261. return self._context_handle.get_all_reduce_fusion_split_indices(group)
  262. def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"):
  263. """
  264. Set allreduce fusion strategy by parameters data sizes.
  265. Args:
  266. sizes (list): Sizes list.
  267. group (str): The hccl communication group.
  268. Raises:
  269. TypeError: If type of sizes item is not int.
  270. TypeError: If group is not a python str.
  271. """
  272. self.check_context_handle()
  273. if isinstance(sizes, (list)):
  274. for size in sizes:
  275. if not isinstance(size, int):
  276. raise TypeError('sizes has invalid value')
  277. else:
  278. raise TypeError('sizes must be a python list')
  279. if isinstance(group, (str)):
  280. group_len = len(group)
  281. if group_len > _MAX_GROUP_NAME_LEN:
  282. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  283. else:
  284. raise TypeError('Group must be a python str')
  285. self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
  286. if context.get_context("device_target") == "Ascend":
  287. _set_fusion_strategy_by_size(sizes)
  288. def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"):
  289. """
  290. Get allreduce fusion split sizes.
  291. Args:
  292. group (str): The hccl communication group.
  293. Returns:
  294. Return split sizes list according to the group.
  295. Raises:
  296. TypeError: If group is not a python str.
  297. """
  298. self.check_context_handle()
  299. if isinstance(group, (str)):
  300. group_len = len(group)
  301. if group_len > _MAX_GROUP_NAME_LEN:
  302. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  303. else:
  304. raise TypeError('Group must be a python str')
  305. return self._context_handle.get_all_reduce_fusion_split_sizes(group)
  306. def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
  307. """
  308. Set enable/disable all reduce fusion.
  309. Args:
  310. enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
  311. """
  312. self.check_context_handle()
  313. if not isinstance(enable_all_reduce_fusion, bool):
  314. raise TypeError('enable_all_reduce_fusion is invalid type')
  315. self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
  316. def get_enable_all_reduce_fusion(self):
  317. """Get all reduce fusion flag."""
  318. self.check_context_handle()
  319. return self._context_handle.get_enable_all_reduce_fusion()
  320. def get_device_num_is_set(self):
  321. """Get device number is set or not."""
  322. self.check_context_handle()
  323. return self._context_handle.get_device_num_is_set()
  324. def get_global_rank_is_set(self):
  325. """Get global rank is set or not."""
  326. self.check_context_handle()
  327. return self._context_handle.get_global_rank_is_set()
  328. def reset(self):
  329. """Reset all settings."""
  330. self.check_context_handle()
  331. self._context_handle.reset()
  332. _auto_parallel_context = None
  333. def auto_parallel_context():
  334. """
  335. Get the global _auto_parallel_context, if it is not created, create a new one.
  336. Returns:
  337. _AutoParallelContext, the global auto parallel context.
  338. """
  339. global _auto_parallel_context
  340. if _auto_parallel_context is None:
  341. _auto_parallel_context = _AutoParallelContext()
  342. return _auto_parallel_context
# Dispatch table mapping each settable context keyword to the corresponding
# setter method on the singleton context (built once at import time).
_set_auto_parallel_context_func_map = {
    "device_num": auto_parallel_context().set_device_num,
    "global_rank": auto_parallel_context().set_global_rank,
    "mirror_mean": auto_parallel_context().set_mirror_mean,
    "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
    "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
    "parallel_mode": auto_parallel_context().set_parallel_mode,
    "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
    "full_batch": auto_parallel_context().set_full_batch}

# Dispatch table mapping each readable context keyword to the corresponding
# getter method on the singleton context; keys mirror the setter table above.
_get_auto_parallel_context_func_map = {
    "device_num": auto_parallel_context().get_device_num,
    "global_rank": auto_parallel_context().get_global_rank,
    "mirror_mean": auto_parallel_context().get_mirror_mean,
    "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
    "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
    "parallel_mode": auto_parallel_context().get_parallel_mode,
    "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
    "full_batch": auto_parallel_context().get_full_batch}
  367. @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
  368. loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
  369. parameter_broadcast=bool, strategy_ckpt_load_file=str,
  370. strategy_ckpt_save_file=str, full_batch=bool)
  371. def _set_auto_parallel_context(**kwargs):
  372. """
  373. Set auto parallel context.
  374. Note:
  375. Attribute name is required for setting attributes.
  376. Args:
  377. device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
  378. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
  379. mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
  380. loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
  381. calculations. Default: True.
  382. cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
  383. parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
  384. "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
  385. - stand_alone: Only one processor working.
  386. - data_parallel: Distributing the data across different processors.
  387. - hybrid_parallel: Achieving data parallelism and model parallelism manually.
  388. - semi_auto_parallel: Achieving data parallelism and model parallelism by
  389. setting parallel strategies.
  390. - auto_parallel: Achieving parallelism automatically.
  391. auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
  392. and "dynamic_programming". Default: "dynamic_programming".
  393. - recursive_programming: Recursive programming search mode.
  394. - dynamic_programming: Dynamic programming search mode.
  395. parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
  396. "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
  397. broadcast. Default: False.
  398. strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
  399. strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
  400. full_batch (bool): Whether to load the whole batch on each device. Default: False.
  401. Raises:
  402. ValueError: If input key is not attribute in auto parallel context.
  403. """
  404. for key, value in kwargs.items():
  405. if key not in _set_auto_parallel_context_func_map:
  406. raise ValueError("Set context keyword %s is not recognized!" % key)
  407. set_func = _set_auto_parallel_context_func_map[key]
  408. set_func(value)
  409. def _get_auto_parallel_context(attr_key):
  410. """
  411. Get auto parallel context attribute value according to the key.
  412. Args:
  413. attr_key (str): The key of the attribute.
  414. Returns:
  415. Return attribute value according to the key.
  416. Raises:
  417. ValueError: If input key is not attribute in auto parallel context.
  418. """
  419. if attr_key not in _get_auto_parallel_context_func_map:
  420. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  421. get_func = _get_auto_parallel_context_func_map[attr_key]
  422. return get_func()
  423. def _reset_auto_parallel_context():
  424. """
  425. Reset auto parallel context attributes to the default values:
  426. - device_num: 1.
  427. - global_rank: 0.
  428. - mirror_mean: False.
  429. - cast_before_mirror: True.
  430. - parallel_mode: "stand_alone".
  431. - parameter_broadcast: False.
  432. - strategy_ckpt_load_file: ""
  433. - strategy_ckpt_save_file: ""
  434. """
  435. auto_parallel_context().reset()