You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_comm_helper.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """comm_helper"""
  16. from ._hccl_management import load_lib as hccl_load_lib
  17. _HCCL_AVAILABLE = False
  18. _NCCL_AVAILABLE = False
  19. try:
  20. import mindspore._ms_mpi as mpi
  21. _NCCL_AVAILABLE = True
  22. except ImportError:
  23. _NCCL_AVAILABLE = False
  24. try:
  25. hccl_load_lib()
  26. _HCCL_AVAILABLE = True
  27. except RuntimeError:
  28. _HCCL_AVAILABLE = False
  29. if _HCCL_AVAILABLE:
  30. from . import _hccl_management as hccl
  31. else:
  32. try:
  33. import hccl_test.manage.api as hccl
  34. _HCCL_AVAILABLE = True
  35. except ImportError:
  36. _HCCL_AVAILABLE = False
  37. HCCL_WORLD_COMM_GROUP = "hccl_world_group"
  38. NCCL_WORLD_COMM_GROUP = "nccl_world_group"
  39. class Backend:
  40. """
  41. Class for available backends.
  42. Note:
  43. The backends' value should be string, e.g., "hccl".
  44. If backend is set to Backend.UNDEFINED, it will be seen as invaliad.
  45. Args:
  46. name (str): The name of backend.
  47. Raises:
  48. TypeError: If name is not a string.
  49. ValueError: If backend is invalid.
  50. Examples:
  51. >>> Backend("abc")
  52. >>> hccl = Backend("hccl")
  53. """
  54. UNDEFINED = "undefined"
  55. HCCL = "hccl"
  56. NCCL = "nccl"
  57. def __new__(cls, name):
  58. """Create instance object of Backend."""
  59. if not isinstance(name, str):
  60. raise TypeError("Backend name must be a string, but got {}".format(type(name)))
  61. value = getattr(Backend, name.upper(), Backend.UNDEFINED)
  62. if value == Backend.UNDEFINED:
  63. raise ValueError("Invalid backend: '{}'".format(name))
  64. return value
  65. def is_hccl_available():
  66. """
  67. Check hccl api is available.
  68. Returns:
  69. Boolean. Return whether hccl is available or not.
  70. """
  71. return _HCCL_AVAILABLE
  72. def is_nccl_available():
  73. """
  74. Check nccl api is available.
  75. Returns:
  76. Boolean. Return whether nccl is available or not.
  77. """
  78. return _NCCL_AVAILABLE
  79. def check_parameter_available(func):
  80. """
  81. Check parameter is available. If not available, raise Error.
  82. Args:
  83. func (Function): The function to be run.
  84. Raises:
  85. RuntimeError.
  86. Returns:
  87. Wrapper. If not available, raise Error.
  88. """
  89. def wrapper(*args, **kargs):
  90. group = None
  91. if "group" in kargs.keys():
  92. group = kargs.get("group")
  93. if group is not None and not isinstance(group, str):
  94. raise TypeError("Group should be str or None, "
  95. "but got group {}".format(type(group)))
  96. if "backend" in kargs.keys():
  97. backend = kargs.get("backend")
  98. if backend is Backend.HCCL and not is_hccl_available():
  99. raise RuntimeError("Distributed Communication doesn't have HCCL built in")
  100. if backend is Backend.NCCL and not is_nccl_available():
  101. raise RuntimeError("Distributed Communication doesn't have NCCL built in")
  102. if group is None:
  103. if backend is Backend.HCCL:
  104. group = HCCL_WORLD_COMM_GROUP
  105. elif backend is Backend.NCCL:
  106. group = NCCL_WORLD_COMM_GROUP
  107. return func(*args, **kargs)
  108. return wrapper
  109. @check_parameter_available
  110. def _get_rank_helper(group, backend):
  111. """
  112. The Helper to do get_rank_id.
  113. Args:
  114. group (str): The communication group.
  115. backend (str): The backend, like "hccl".
  116. Raises:
  117. ValueError: If backend is invalid.
  118. Returns:
  119. Integer. The local rank id of the calling process.
  120. """
  121. rank_id = None
  122. if backend == Backend.HCCL:
  123. if group == HCCL_WORLD_COMM_GROUP:
  124. rank_id = hccl.get_rank_id()
  125. else:
  126. rank_id = hccl.get_rank_id(group)
  127. elif backend == Backend.NCCL:
  128. if group == NCCL_WORLD_COMM_GROUP:
  129. rank_id = mpi.get_rank_id()
  130. else:
  131. raise RuntimeError("Nccl doesn't support get_rank_id by user group now.")
  132. else:
  133. raise ValueError("Invalid backend: '{}'".format(backend))
  134. return rank_id
  135. @check_parameter_available
  136. def _get_local_rank_helper(group, backend):
  137. """
  138. The Helper to do get_local_rank_id.
  139. Args:
  140. group (str): The communication group.
  141. backend (str): The backend, like "hccl".
  142. Raises:
  143. ValueError: If backend is invalid.
  144. Returns:
  145. Integer. The local rank id of the calling process.
  146. """
  147. rank_id = None
  148. if backend == Backend.HCCL:
  149. if group == HCCL_WORLD_COMM_GROUP:
  150. rank_id = hccl.get_local_rank_id()
  151. else:
  152. rank_id = hccl.get_local_rank_id(group)
  153. elif backend == Backend.NCCL:
  154. raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
  155. else:
  156. raise ValueError("Invalid backend: '{}'".format(backend))
  157. return rank_id
  158. @check_parameter_available
  159. def _get_size_helper(group, backend):
  160. """
  161. The Helper to do get_rank_size.
  162. Args:
  163. group (str): The communication group.
  164. backend (str): The backend, like "hccl".
  165. Raises:
  166. ValueError: If backend is invalid.
  167. Returns:
  168. Integer. The rank size of specified group.
  169. """
  170. size = None
  171. if backend == Backend.HCCL:
  172. if group == HCCL_WORLD_COMM_GROUP:
  173. size = hccl.get_rank_size()
  174. else:
  175. size = hccl.get_rank_size(group)
  176. elif backend == Backend.NCCL:
  177. if group == NCCL_WORLD_COMM_GROUP:
  178. size = mpi.get_rank_size()
  179. else:
  180. raise RuntimeError("Nccl doesn't support get_rank_size by user group now.")
  181. else:
  182. raise ValueError("Invalid backend: '{}'".format(backend))
  183. return size
  184. @check_parameter_available
  185. def _get_local_size_helper(group, backend):
  186. """
  187. The Helper to do get_local_rank_size.
  188. Args:
  189. group (str): The communication group.
  190. backend (str): The backend, like "hccl".
  191. Raises:
  192. ValueError: If backend is invalid.
  193. Returns:
  194. Integer. The local rank size where the calling process is being within specified group.
  195. """
  196. size = None
  197. if backend == Backend.HCCL:
  198. if group == HCCL_WORLD_COMM_GROUP:
  199. size = hccl.get_local_rank_size()
  200. else:
  201. size = hccl.get_local_rank_size(group)
  202. elif backend == Backend.NCCL:
  203. raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
  204. else:
  205. raise ValueError("Invalid backend: '{}'".format(backend))
  206. return size
  207. @check_parameter_available
  208. def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
  209. """
  210. The Helper to do get_world_rank_from_group_rank.
  211. Args:
  212. group (str): The user communication group.
  213. group_rank_id (int): A rank id in user communication group.
  214. backend (str): The backend, like "hccl".
  215. Raises:
  216. TypeError: If group_rank_id is not int.
  217. ValueError: If group is "hccl_world_group" or backend is invalid.
  218. Returns:
  219. Integer. A rank id in world communication group.
  220. """
  221. world_rank_id = None
  222. if not isinstance(group_rank_id, int):
  223. raise TypeError("group_rank_id should be int, but got type {}".format(type(group_rank_id)))
  224. if backend == Backend.HCCL:
  225. if group == HCCL_WORLD_COMM_GROUP:
  226. raise ValueError("Group cannot be 'hccl_world_group'. ")
  227. world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id)
  228. elif backend == Backend.NCCL:
  229. raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
  230. else:
  231. raise ValueError("Invalid backend: '{}'".format(backend))
  232. return world_rank_id
  233. @check_parameter_available
  234. def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
  235. """
  236. The Helper to do get_group_rank_from_world_rank.
  237. Args:
  238. world_rank_id (int): A rank id in world communication group.
  239. group (str): The user communication group.
  240. backend (str): The backend, like "hccl".
  241. Raises:
  242. TypeError: If world_rank_id is not int.
  243. ValueError: If group is 'hccl_world_group' or backend is invalid.
  244. Returns:
  245. Integer. A rank id in user communication group.
  246. """
  247. group_rank_id = None
  248. if not isinstance(world_rank_id, int):
  249. raise TypeError("world_rank_id should be int, but got type {}".format(type(world_rank_id)))
  250. if backend == Backend.HCCL:
  251. if group == HCCL_WORLD_COMM_GROUP:
  252. raise ValueError("Group cannot be 'hccl_world_group'. ")
  253. group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group)
  254. elif backend == Backend.NCCL:
  255. raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
  256. else:
  257. raise ValueError("Invalid backend: '{}'".format(backend))
  258. return group_rank_id
  259. @check_parameter_available
  260. def _create_group_helper(group, rank_ids, backend):
  261. """
  262. The Helper to do create_group.
  263. Args:
  264. group (str): The communication group.
  265. rank_ids (list): Rank ids in the group.
  266. backend (str): The backend, like "hccl".
  267. Raises:
  268. TypeError: If rank_ids is not a list.
  269. ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
  270. """
  271. if backend == Backend.HCCL:
  272. if not isinstance(rank_ids, list):
  273. raise TypeError("Rank_ids {} should be list".format(rank_ids))
  274. rank_size = len(rank_ids)
  275. if rank_size < 1:
  276. raise ValueError("Rank_ids size {} should be large than 0".format(rank_size))
  277. if len(rank_ids) - len(list(set(rank_ids))) > 0:
  278. raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
  279. hccl.create_group(group, rank_size, rank_ids)
  280. elif backend == Backend.NCCL:
  281. raise RuntimeError("Nccl doesn't support create_group now.")
  282. else:
  283. raise ValueError("Invalid backend: '{}'".format(backend))
  284. @check_parameter_available
  285. def _destroy_group_helper(group, backend):
  286. """
  287. The Helper to do destroy_group.
  288. Args:
  289. group (str): The user communication group.
  290. backend (str): The backend, like "hccl".
  291. Raises:
  292. ValueError: If group is "hccl_world_group" or backend is invalid.
  293. """
  294. if backend == Backend.HCCL:
  295. if group == HCCL_WORLD_COMM_GROUP:
  296. raise ValueError("The hccl_world_group does not support destruction.")
  297. hccl.destroy_group(group)
  298. elif backend == Backend.NCCL:
  299. raise RuntimeError("Nccl doesn't support destroy_group now.")
  300. else:
  301. raise ValueError("Invalid backend: '{}'".format(backend))