Browse Source

fix example error of HostReduceScatter

tags/v0.5.0-beta
Yi Huaijie 5 years ago
parent
commit
c9395e07b5
1 changed files with 5 additions and 6 deletions
  1. +5
    -6
      mindspore/ops/operations/comm_ops.py

+ 5
- 6
mindspore/ops/operations/comm_ops.py View File

@@ -191,8 +191,7 @@ class HostAllGather(PrimitiveWithInfer):

Raises:
TypeError: If group is not a list nor tuple, or elements of group are not int.
ValueError: If the local rank id of the calling process not in group,
or rank_id from group not in [0, 7].
ValueError: If group is not set, or rank_id from group not in [0, 7].

Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -281,7 +280,7 @@ class ReduceScatter(PrimitiveWithInfer):
>>> def construct(self, x):
>>> return self.reducescatter(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@@ -327,8 +326,8 @@ class HostReduceScatter(PrimitiveWithInfer):
Raises:
TypeError: If op is not a string and group is not a list nor tuple,
or elements of group are not int.
ValueError: If the first dimension of input can not be divided by rank size,
or group is not set, or rank_id not in [1, 7].
ValueError: If the first dimension of input can not be divided by group size,
or group is not set, or rank_id not in [0, 7].

Examples:
>>> import mindspore.nn as nn
@@ -348,7 +347,7 @@ class HostReduceScatter(PrimitiveWithInfer):
>>> def construct(self, x):
>>> return self.hostreducescatter(x)
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""


Loading…
Cancel
Save