You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_auto_parallel_context.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context of auto parallel"""
  16. import threading
  17. import mindspore.context as context
  18. from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
  19. from mindspore._c_expression import AutoParallelContext
  20. from mindspore._checkparam import args_type_check
  21. class _AutoParallelContext:
  22. """
  23. _AutoParallelContext is the environment in which operations are executed
  24. Note:
  25. Create a context through instantiating Context object is not recommended.
  26. Should use auto_parallel_context() to get the context since Context is singleton.
  27. """
  28. _instance = None
  29. _instance_lock = threading.Lock()
  30. def __init__(self):
  31. self._context_handle = AutoParallelContext.get_instance()
  32. def __new__(cls):
  33. if cls._instance is None:
  34. cls._instance_lock.acquire()
  35. cls._instance = object.__new__(cls)
  36. cls._instance_lock.release()
  37. return cls._instance
  38. def check_context_handle(self):
  39. """
  40. Check context handle.
  41. Raises:
  42. ValueError: If the context handle is none.
  43. """
  44. if self._context_handle is None:
  45. raise ValueError("Context handle is none in context!!!")
  46. def set_device_num(self, device_num):
  47. """
  48. Set device num for auto parallel.
  49. Args:
  50. device_num (int): The device number.
  51. Raises:
  52. ValueError: If the device num is not in [1, 4096].
  53. """
  54. self.check_context_handle()
  55. if device_num < 1 or device_num > 4096:
  56. raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
  57. self._context_handle.set_device_num(device_num)
  58. def get_device_num(self):
  59. """Get device num."""
  60. self.check_context_handle()
  61. return self._context_handle.get_device_num()
  62. def set_global_rank(self, global_rank):
  63. """
  64. Set global rank for auto parallel.
  65. Args:
  66. global_rank (int): The rank id of current rank.
  67. Raises:
  68. ValueError: If the global rank is not in [1, 4096].
  69. """
  70. self.check_context_handle()
  71. if global_rank < 0 or global_rank > 4095:
  72. raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank))
  73. self._context_handle.set_global_rank(global_rank)
  74. def get_global_rank(self):
  75. """Get current rank id."""
  76. self.check_context_handle()
  77. return self._context_handle.get_global_rank()
  78. def set_mirror_mean(self, mirror_mean):
  79. """
  80. Set mirror_mean flag.
  81. Note:
  82. If mirror_mean is true, it will insert a div operator after parameter gradients allreduce.
  83. Args:
  84. mirror_mean (bool): The mirror_mean flag.
  85. """
  86. self.check_context_handle()
  87. self._context_handle.set_mirror_mean(mirror_mean)
  88. def get_mirror_mean(self):
  89. """Get mirror_mean flag."""
  90. self.check_context_handle()
  91. return self._context_handle.get_mirror_mean()
  92. def set_cast_before_mirror(self, cast_before_mirror):
  93. """
  94. Set cast_before_mirror.
  95. Note:
  96. If cast_before_mirror is true,
  97. it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
  98. Args:
  99. cast_before_mirror (bool): The cast_before_mirror flag.
  100. """
  101. self.check_context_handle()
  102. self._context_handle.set_cast_before_mirror(cast_before_mirror)
  103. def get_cast_before_mirror(self):
  104. """Get cast_before_mirror flag."""
  105. self.check_context_handle()
  106. return self._context_handle.get_cast_before_mirror()
  107. def set_loss_repeated_mean(self, loss_repeated_mean):
  108. """
  109. Set loss_repeated_mean flag.
  110. Note:
  111. If loss_repeated_mean is true,
  112. Distributed automatic differentiation will perform a mean operator
  113. in backward in the case of repeated calculations.
  114. Args:
  115. loss_repeated_mean (bool): The loss_repeated_mean flag.
  116. """
  117. self.check_context_handle()
  118. self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
  119. def get_loss_repeated_mean(self):
  120. """Get loss_repeated_mean flag."""
  121. self.check_context_handle()
  122. return self._context_handle.get_loss_repeated_mean()
  123. def set_communication_backend(self, communication_backend):
  124. """
  125. Set communication backend.
  126. Args:
  127. communication_backend (str): The communication backend.
  128. """
  129. self.check_context_handle()
  130. self._context_handle.set_communication_backend(communication_backend)
  131. def get_communication_backend(self):
  132. """Get communication backend."""
  133. self.check_context_handle()
  134. return self._context_handle.get_communication_backend()
  135. def set_parallel_mode(self, parallel_mode):
  136. """
  137. Set parallel mode for auto parallel.
  138. Args:
  139. parallel_mode (str): The parallel mode of auto parallel.
  140. Raises:
  141. ValueError: If parallel mode is not supported.
  142. """
  143. self.check_context_handle()
  144. ret = self._context_handle.set_parallel_mode(parallel_mode)
  145. if ret is False:
  146. raise ValueError("Parallel mode does not support {}".format(parallel_mode))
  147. def get_parallel_mode(self):
  148. """Get parallel mode."""
  149. self.check_context_handle()
  150. return self._context_handle.get_parallel_mode()
  151. def set_strategy_search_mode(self, strategy_search_mode):
  152. self.check_context_handle()
  153. ret = self._context_handle.set_strategy_search_mode(strategy_search_mode)
  154. if ret is False:
  155. raise ValueError("Strategy search mode does not support {}".format(strategy_search_mode))
  156. def get_strategy_search_mode(self):
  157. self.check_context_handle()
  158. return self._context_handle.get_strategy_search_mode()
  159. def set_parameter_broadcast(self, parameter_broadcast):
  160. """
  161. Set parameter broadcast.
  162. Args:
  163. parameter_broadcast (bool): Parameter broadcast or not.
  164. """
  165. self.check_context_handle()
  166. self._context_handle.set_parameter_broadcast(parameter_broadcast)
  167. def get_parameter_broadcast(self):
  168. """Get parameter broadcast flag."""
  169. self.check_context_handle()
  170. return self._context_handle.get_parameter_broadcast()
  171. def get_parameter_broadcast_is_set(self):
  172. """Get parameter broadcast is set or not."""
  173. self.check_context_handle()
  174. return self._context_handle.get_parameter_broadcast_is_set()
  175. def set_all_reduce_fusion_split_indices(self, indices):
  176. """
  177. Set allreduce fusion strategy by parameters indices.
  178. Args:
  179. indices (list): Indices list.
  180. Raises:
  181. TypeError: If type of indices item is not int.
  182. """
  183. self.check_context_handle()
  184. for index in indices:
  185. if not isinstance(index, int):
  186. raise TypeError('indices has invalid value')
  187. self._context_handle.set_all_reduce_fusion_split_indices(indices)
  188. if context.get_context("device_target") == "Ascend":
  189. _set_fusion_strategy_by_idx(indices)
  190. def get_all_reduce_fusion_split_indices(self):
  191. """Get allreduce fusion split indices."""
  192. self.check_context_handle()
  193. return self._context_handle.get_all_reduce_fusion_split_indices()
  194. def set_all_reduce_fusion_split_sizes(self, sizes):
  195. """
  196. Set allreduce fusion strategy by parameters data sizes.
  197. Args:
  198. sizes (list): Sizes list.
  199. Raises:
  200. TypeError: If type of sizes item is not int.
  201. """
  202. self.check_context_handle()
  203. for size in sizes:
  204. if not isinstance(size, int):
  205. raise TypeError('sizes has invalid value')
  206. self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
  207. if context.get_context("device_target") == "Ascend":
  208. _set_fusion_strategy_by_size(sizes)
  209. def get_all_reduce_fusion_split_sizes(self):
  210. """Get allreduce fusion split sizes."""
  211. self.check_context_handle()
  212. return self._context_handle.get_all_reduce_fusion_split_sizes()
  213. def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
  214. """
  215. Set enable/disable all reduce fusion.
  216. Args:
  217. enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
  218. """
  219. self.check_context_handle()
  220. if not isinstance(enable_all_reduce_fusion, bool):
  221. raise TypeError('enable_all_reduce_fusion is invalid type')
  222. self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
  223. def get_enable_all_reduce_fusion(self):
  224. """Get all reduce fusion flag."""
  225. self.check_context_handle()
  226. return self._context_handle.get_enable_all_reduce_fusion()
  227. def get_device_num_is_set(self):
  228. """Get device number is set or not."""
  229. self.check_context_handle()
  230. return self._context_handle.get_device_num_is_set()
  231. def get_global_rank_is_set(self):
  232. """Get global rank is set or not."""
  233. self.check_context_handle()
  234. return self._context_handle.get_global_rank_is_set()
  235. def reset(self):
  236. """Reset all settings."""
  237. self.check_context_handle()
  238. self._context_handle.reset()
  239. _auto_parallel_context = None
  240. def auto_parallel_context():
  241. """
  242. Get the global _auto_parallel_context, if it is not created, create a new one.
  243. Returns:
  244. _AutoParallelContext, the global auto parallel context.
  245. """
  246. global _auto_parallel_context
  247. if _auto_parallel_context is None:
  248. _auto_parallel_context = _AutoParallelContext()
  249. return _auto_parallel_context
  250. _set_auto_parallel_context_func_map = {
  251. "device_num": auto_parallel_context().set_device_num,
  252. "global_rank": auto_parallel_context().set_global_rank,
  253. "mirror_mean": auto_parallel_context().set_mirror_mean,
  254. "cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
  255. "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
  256. "parallel_mode": auto_parallel_context().set_parallel_mode,
  257. "parameter_broadcast": auto_parallel_context().set_parameter_broadcast}
  258. _get_auto_parallel_context_func_map = {
  259. "device_num": auto_parallel_context().get_device_num,
  260. "global_rank": auto_parallel_context().get_global_rank,
  261. "mirror_mean": auto_parallel_context().get_mirror_mean,
  262. "cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
  263. "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
  264. "parallel_mode": auto_parallel_context().get_parallel_mode,
  265. "parameter_broadcast": auto_parallel_context().get_parameter_broadcast}
  266. @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
  267. loss_repeated_mean=bool, parallel_mode=str, parameter_broadcast=bool)
  268. def _set_auto_parallel_context(**kwargs):
  269. """
  270. Set auto parallel context.
  271. Note:
  272. Attribute name is required for setting attributes.
  273. Args:
  274. device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
  275. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
  276. mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
  277. loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
  278. calculations. Default: True.
  279. cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
  280. parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
  281. "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
  282. - stand_alone: Only one processor working.
  283. - data_parallel: Distributing the data across different processors.
  284. - hybrid_parallel: Achieving data parallelism and model parallelism manually.
  285. - semi_auto_parallel: Achieving data parallelism and model parallelism by
  286. setting parallel strategies.
  287. - auto_parallel: Achieving parallelism automatically.
  288. parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
  289. "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
  290. broadcast. Default: False.
  291. Raises:
  292. ValueError: If input key is not attribute in auto parallel context.
  293. """
  294. for key, value in kwargs.items():
  295. if key not in _set_auto_parallel_context_func_map:
  296. raise ValueError("Set context keyword %s is not recognized!" % key)
  297. set_func = _set_auto_parallel_context_func_map[key]
  298. set_func(value)
  299. def _get_auto_parallel_context(attr_key):
  300. """
  301. Get auto parallel context attribute value according to the key.
  302. Args:
  303. attr_key (str): The key of the attribute.
  304. Returns:
  305. Return attribute value according to the key.
  306. Raises:
  307. ValueError: If input key is not attribute in auto parallel context.
  308. """
  309. if attr_key not in _get_auto_parallel_context_func_map:
  310. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  311. get_func = _get_auto_parallel_context_func_map[attr_key]
  312. return get_func()
  313. def _reset_auto_parallel_context():
  314. """
  315. Reset auto parallel context attributes to the default values:
  316. - device_num: 1.
  317. - global_rank: 0.
  318. - mirror_mean: False.
  319. - cast_before_mirror: True.
  320. - parallel_mode: "stand_alone".
  321. - parameter_broadcast: False.
  322. """
  323. auto_parallel_context().reset()