You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

comm_ops.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """comm_ops"""
  16. from mindspore.common import Tensor
  17. from ..._checkparam import Validator as validator
  18. from ..._checkparam import Rel
  19. from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
  20. from ...common import dtype as mstype
  21. from ..primitive import PrimitiveWithInfer, prim_attr_register
  22. class ReduceOp:
  23. """
  24. Operation options for reduce tensors.
  25. There are four kinds of operation options, "SUM", "MAX", "MIN", and "PROD".
  26. - SUM: Take the sum.
  27. - MAX: Take the maximum.
  28. - MIN: Take the minimum.
  29. - PROD: Take the product.
  30. """
  31. SUM = "sum"
  32. MAX = "max"
  33. MIN = "min"
  34. PROD = "prod"
  35. target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
  36. class AllReduce(PrimitiveWithInfer):
  37. """
  38. Reduces the tensor data across all devices in such a way that all devices will get the same final result.
  39. Note:
  40. The operation of AllReduce does not support "prod" currently.
  41. The tensors must have the same shape and format in all processes of the collection.
  42. Args:
  43. op (str): Specifies an operation used for element-wise reductions,
  44. like sum, max, and min. Default: ReduceOp.SUM.
  45. group (str): The communication group to work on. Default: "hccl_world_group".
  46. Raises:
  47. TypeError: If any of operation and group is not a string,
  48. or fusion is not an integer, or the input's dtype is bool.
  49. ValueError: If the operation is "prod".
  50. Inputs:
  51. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  52. Outputs:
  53. Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
  54. The contents depend on the specified operation.
  55. Examples:
  56. >>> from mindspore.communication import init
  57. >>> from mindspore import Tensor
  58. >>> from mindspore.ops.operations.comm_ops import ReduceOp
  59. >>> import mindspore.nn as nn
  60. >>> import mindspore.ops.operations as P
  61. >>>
  62. >>> init()
  63. >>> class Net(nn.Cell):
  64. >>> def __init__(self):
  65. >>> super(Net, self).__init__()
  66. >>> self.allreduce_sum = P.AllReduce(ReduceOp.SUM, group="nccl_world_group")
  67. >>>
  68. >>> def construct(self, x):
  69. >>> return self.allreduce_sum(x)
  70. >>>
  71. >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
  72. >>> net = Net()
  73. >>> output = net(input_)
  74. """
  75. @prim_attr_register
  76. def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
  77. if not isinstance(op, type(ReduceOp.SUM)):
  78. raise TypeError("The operation of AllReduce should be str.")
  79. if op == ReduceOp.PROD:
  80. raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.")
  81. if not isinstance(_get_group(group), str):
  82. raise TypeError("The group of AllReduce should be str.")
  83. self.op = op
  84. self.add_prim_attr('group', _get_group(group))
  85. self.add_prim_attr('fusion', 0)
  86. self.add_prim_attr('index', 0)
  87. def infer_shape(self, x_shape):
  88. return x_shape
  89. def infer_dtype(self, x_dtype):
  90. validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
  91. return x_dtype
  92. class AllGather(PrimitiveWithInfer):
  93. """
  94. Gathers tensors from the specified communication group.
  95. Note:
  96. The tensors must have the same shape and format in all processes of the collection.
  97. Args:
  98. group (str): The communication group to work on. Default: "hccl_world_group".
  99. Raises:
  100. TypeError: If group is not a string.
  101. ValueError: If the local rank id of the calling process in the group
  102. is larger than the group's rank size.
  103. Inputs:
  104. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  105. Outputs:
  106. Tensor. If the number of devices in the group is N,
  107. then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
  108. Examples:
  109. >>> import mindspore.ops.operations as P
  110. >>> import mindspore.nn as nn
  111. >>> from mindspore.communication import init
  112. >>> from mindspore import Tensor
  113. >>>
  114. >>> init()
  115. >>> class Net(nn.Cell):
  116. >>> def __init__(self):
  117. >>> super(Net, self).__init__()
  118. >>> self.allgather = P.AllGather(group="nccl_world_group")
  119. >>>
  120. >>> def construct(self, x):
  121. >>> return self.allgather(x)
  122. >>>
  123. >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
  124. >>> net = Net()
  125. >>> output = net(input_)
  126. """
  127. @prim_attr_register
  128. def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
  129. validator.check_value_type('group', _get_group(group), (str,), self.name)
  130. self.rank = get_rank(_get_group(group))
  131. self.rank_size = get_group_size(_get_group(group))
  132. validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
  133. self.add_prim_attr('rank_size', self.rank_size)
  134. self.add_prim_attr('group', _get_group(group))
  135. self.add_prim_attr('fusion', 0)
  136. def infer_shape(self, x_shape):
  137. validator.check_positive_int(len(x_shape), "x shape", self.name)
  138. x_shape[0] = x_shape[0] * self.rank_size
  139. return x_shape
  140. def infer_dtype(self, x_dtype):
  141. validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
  142. return x_dtype
  143. def __call__(self, tensor):
  144. raise NotImplementedError
  145. class _HostAllGather(PrimitiveWithInfer):
  146. """
  147. Gathers tensors from the specified communication group on host.
  148. Note:
  149. The tensors must have the same shape and format in all processes of the collection.
  150. _HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
  151. to enable it. Using mpirun command to run it:
  152. mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
  153. Args:
  154. group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
  155. Raises:
  156. TypeError: If group is not a list nor tuple, or elements of group are not int.
  157. ValueError: If group is not set, or rank_id from group not in [0, 7].
  158. Inputs:
  159. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  160. Outputs:
  161. Tensor. If the number of devices in the group is N,
  162. then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
  163. """
  164. @prim_attr_register
  165. def __init__(self, group=None):
  166. if group is None:
  167. raise ValueError(f"For '{self.name}' group must be set.")
  168. validator.check_value_type('group', group, (tuple, list), self.name)
  169. validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
  170. for r in group:
  171. validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
  172. validator.check_value_type("rank_id", r, (int,), self.name)
  173. self.group_size = len(group)
  174. self.add_prim_attr('group', group)
  175. def infer_shape(self, x_shape):
  176. validator.check_positive_int(len(x_shape), "x shape", self.name)
  177. x_shape[0] = x_shape[0] * self.group_size
  178. return x_shape
  179. def infer_dtype(self, x_dtype):
  180. validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
  181. return x_dtype
  182. def __call__(self, tensor):
  183. raise NotImplementedError
  184. class ReduceScatter(PrimitiveWithInfer):
  185. """
  186. Reduces and scatters tensors from the specified communication group.
  187. Note:
  188. The back propagation of the op is not supported yet. Stay tuned for more.
  189. The tensors must have the same shape and format in all processes of the collection.
  190. Args:
  191. op (str): Specifies an operation used for element-wise reductions,
  192. like SUM, MAX, AVG. Default: ReduceOp.SUM.
  193. group (str): The communication group to work on. Default: "hccl_world_group".
  194. Raises:
  195. TypeError: If any of operation and group is not a string.
  196. ValueError: If the first dimension of the input cannot be divided by the rank size.
  197. Examples:
  198. >>> from mindspore import Tensor
  199. >>> from mindspore.communication import init
  200. >>> from mindspore.ops.operations.comm_ops import ReduceOp
  201. >>> import mindspore.nn as nn
  202. >>> import mindspore.ops.operations as P
  203. >>>
  204. >>> init()
  205. >>> class Net(nn.Cell):
  206. >>> def __init__(self):
  207. >>> super(Net, self).__init__()
  208. >>> self.reducescatter = P.ReduceScatter(ReduceOp.SUM, group="nccl_world_group")
  209. >>>
  210. >>> def construct(self, x):
  211. >>> return self.reducescatter(x)
  212. >>>
  213. >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
  214. >>> net = Net()
  215. >>> output = net(input_)
  216. """
  217. @prim_attr_register
  218. def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
  219. validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
  220. validator.check_value_type('group', _get_group(group), (str,), self.name)
  221. self.op = op
  222. self.rank_size = get_group_size(_get_group(group))
  223. self.add_prim_attr('rank_size', self.rank_size)
  224. self.add_prim_attr('group', _get_group(group))
  225. self.add_prim_attr('fusion', 0)
  226. def infer_shape(self, x_shape):
  227. if x_shape[0] % self.rank_size != 0:
  228. raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.")
  229. x_shape[0] = int(x_shape[0]/self.rank_size)
  230. return x_shape
  231. def infer_dtype(self, x_dtype):
  232. validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
  233. return x_dtype
  234. def __call__(self, tensor):
  235. raise NotImplementedError
  236. class _HostReduceScatter(PrimitiveWithInfer):
  237. """
  238. Reduces and scatters tensors from the specified communication group on host.
  239. Note:
  240. The tensors must have the same shape and format in all processes of the collection.
  241. _HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
  242. -M on to enable it. Using mpirun command to run it:
  243. mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py
  244. Args:
  245. op (str): Specifies an operation used for element-wise reductions,
  246. like sum, max, avg. Default: ReduceOp.SUM.
  247. group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
  248. Raises:
  249. TypeError: If op is not a string and group is not a list nor tuple,
  250. or elements of group are not int.
  251. ValueError: If the first dimension of input can not be divided by group size,
  252. or group is not set, or rank_id not in [0, 7].
  253. """
  254. @prim_attr_register
  255. def __init__(self, op=ReduceOp.SUM, group=None):
  256. if group is None:
  257. raise ValueError(f"For '{self.name}' group must be set.")
  258. validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
  259. validator.check_value_type('group', group, (tuple, list), self.name)
  260. validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
  261. for r in group:
  262. validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
  263. validator.check_value_type("rank_id", r, (int,), self.name)
  264. self.op = op
  265. self.group_size = len(group)
  266. self.add_prim_attr('group', group)
  267. def infer_shape(self, x_shape):
  268. if x_shape[0] % self.group_size != 0:
  269. raise ValueError(f"For '{self.name}' the first dimension of x should be divided by group_size.")
  270. x_shape[0] = int(x_shape[0]/self.group_size)
  271. return x_shape
  272. def infer_dtype(self, x_dtype):
  273. validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
  274. return x_dtype
  275. def __call__(self, tensor):
  276. raise NotImplementedError
  277. class Broadcast(PrimitiveWithInfer):
  278. """
  279. Broadcasts the tensor to the whole group.
  280. Note:
  281. The tensors must have the same shape and format in all processes of the collection.
  282. Args:
  283. root_rank (int): Source rank. Required in all processes except the one
  284. that is sending the data.
  285. group (str): The communication group to work on. Default: "hccl_world_group".
  286. Inputs:
  287. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  288. Outputs:
  289. Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
  290. The contents depend on the data of the `root_rank` device.
  291. Raises:
  292. TypeError: If root_rank is not a integer or group is not a string.
  293. Examples:
  294. >>> from mindspore import Tensor
  295. >>> from mindspore.communication import init
  296. >>> import mindspore.nn as nn
  297. >>> import mindspore.ops.operations as P
  298. >>>
  299. >>> init()
  300. >>> class Net(nn.Cell):
  301. >>> def __init__(self):
  302. >>> super(Net, self).__init__()
  303. >>> self.broadcast = P.Broadcast(1)
  304. >>>
  305. >>> def construct(self, x):
  306. >>> return self.broadcast((x,))
  307. >>>
  308. >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
  309. >>> net = Net()
  310. >>> output = net(input_)
  311. """
  312. @prim_attr_register
  313. def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
  314. validator.check_value_type('root_rank', root_rank, (int,), self.name)
  315. validator.check_value_type('group', _get_group(group), (str,), self.name)
  316. self.add_prim_attr('group', _get_group(group))
  317. def infer_shape(self, x_shape):
  318. return x_shape
  319. def infer_dtype(self, x_dtype):
  320. if not isinstance(x_dtype, tuple):
  321. raise TypeError(f"{self.name}'s input should be a tuple!")
  322. for _ele in x_dtype:
  323. validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name)
  324. return x_dtype
  325. class _AlltoAll(PrimitiveWithInfer):
  326. """
  327. AlltoAll is a collective operation.
  328. AlltoAll sends data from the all processes to the all processes in the specified group. It has two phases:
  329. - The scatter phase: On each process, the operand is split into split_count number of blocks along the
  330. split_dimensions, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process.
  331. - The gather phase: Each process concatenates the received blocks along the concat_dimension.
  332. Note:
  333. The tensors must have the same shape and format in all processes of the collection.
  334. Args:
  335. split_count (int): On each process, divide blocks into split_count number.
  336. split_dim (int): On each process, split blocks along the split_dim.
  337. concat_dim (int): On each process, gather the received blocks along the concat_dimension.
  338. group (str): The communication group to work on. Default: "hccl_world_group".
  339. Raises:
  340. TypeError: If group is not a string.
  341. """
  342. @prim_attr_register
  343. def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
  344. """Initialize AlltoAll"""
  345. validator.check_value_type('group', _get_group(group), (str,), self.name)
  346. self.split_count = split_count
  347. self.split_dim = split_dim
  348. self.concat_dim = concat_dim
  349. self.add_prim_attr('group', _get_group(group))
  350. def infer_shape(self, x_shape):
  351. x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
  352. x_shape[self.split_dim] = int(x_shape[self.split_dim] / self.split_count)
  353. return x_shape
  354. def infer_dtype(self, x_dtype):
  355. validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
  356. return x_dtype
  357. def __call__(self, tensor):
  358. return
  359. class _MirrorOperator(PrimitiveWithInfer):
  360. """
  361. Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
  362. internal use of parallel modules and cannot be called by users.
  363. Args:
  364. group (str): The communication group to work on. Default: None.
  365. dev_num (int): The device number of the group. Default: None.
  366. mean_flag (bool): Whether use mean in backward. Default: None.
  367. """
  368. @prim_attr_register
  369. def __init__(self, group=None, dev_num=None, mean_flag=None):
  370. self.group = group
  371. self.dev_num = dev_num
  372. self.mean_flag = mean_flag
  373. def infer_shape(self, x_shape):
  374. return x_shape
  375. def infer_dtype(self, x_dtype):
  376. return x_dtype
  377. mirror = _MirrorOperator()
  378. class _VirtualDiv(PrimitiveWithInfer):
  379. """
  380. Auto parallel virtual operator. Do nothing in forward, do Div in backward.
  381. Args:
  382. divisor: float32
  383. """
  384. @prim_attr_register
  385. def __init__(self, divisor=None):
  386. self.divisor = divisor
  387. def infer_shape(self, x_shape):
  388. return x_shape
  389. def infer_dtype(self, x_dtype):
  390. return x_dtype
  391. virtual_div = _VirtualDiv()
  392. class _VirtualDataset(PrimitiveWithInfer):
  393. """
  394. Auto parallel virtual dataset operator.
  395. It would insert Broadcast operator in forward computation and be deleted before backward computation.
  396. """
  397. @prim_attr_register
  398. def __init__(self):
  399. """init"""
  400. def infer_shape(self, *args):
  401. if len(args) == 1:
  402. return args[0]
  403. return args
  404. def infer_dtype(self, *args):
  405. if len(args) == 1:
  406. return args[0]
  407. return args
  408. virtual_dataset = _VirtualDataset()
  409. class _GetTensorSlice(PrimitiveWithInfer):
  410. """
  411. Gets tensor slice by device matrix and tensor map.
  412. Args:
  413. dev_mat (tuple): The device matrix of the slice tensor.
  414. tensor_map (tuple): The tensor map of the slice tensor.
  415. """
  416. @prim_attr_register
  417. def __init__(self):
  418. """Initialize ChunkTensor"""
  419. def infer_value(self, x, dev_mat, tensor_map):
  420. from mindspore.parallel._tensor import _load_tensor
  421. validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
  422. validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
  423. return Tensor(_load_tensor(x, dev_mat, tensor_map))