|
|
|
@@ -142,16 +142,19 @@ class AllGather(PrimitiveWithInfer): |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> # This example should be run with two devices. Refer to the tutorial > Distirbuted Training on mindspore.cn. |
|
|
|
>>> import numpy as np |
|
|
|
>>> import mindspore.ops.operations as ops |
|
|
|
>>> import mindspore.nn as nn |
|
|
|
>>> from mindspore.communication import init |
|
|
|
>>> from mindspore import Tensor |
|
|
|
>>> from mindspore import Tensor, context |
|
|
|
>>> |
|
|
|
>>> context.set_context(mode=context.GRAPH_MODE) |
|
|
|
>>> init() |
|
|
|
... class Net(nn.Cell): |
|
|
|
... def __init__(self): |
|
|
|
... super(Net, self).__init__() |
|
|
|
... self.allgather = ops.AllGather(group="nccl_world_group") |
|
|
|
... self.allgather = ops.AllGather() |
|
|
|
... |
|
|
|
... def construct(self, x): |
|
|
|
... return self.allgather(x) |
|
|
|
@@ -160,6 +163,10 @@ class AllGather(PrimitiveWithInfer): |
|
|
|
>>> net = Net() |
|
|
|
>>> output = net(input_) |
|
|
|
>>> print(output) |
|
|
|
[[1. 1. 1. 1. 1. 1. 1. 1.] |
|
|
|
[1. 1. 1. 1. 1. 1. 1. 1.] |
|
|
|
[1. 1. 1. 1. 1. 1. 1. 1.] |
|
|
|
[1. 1. 1. 1. 1. 1. 1. 1.]] |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
@@ -255,16 +262,18 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
|
ValueError: If the first dimension of the input cannot be divided by the rank size. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``GPU`` |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore import Tensor |
|
|
|
>>> # This example should be run with two devices. Refer to the tutorial > Distirbuted Training on mindspore.cn. |
|
|
|
>>> from mindspore import Tensor, context |
|
|
|
>>> from mindspore.communication import init |
|
|
|
>>> from mindspore.ops.operations.comm_ops import ReduceOp |
|
|
|
>>> import mindspore.nn as nn |
|
|
|
>>> import mindspore.ops.operations as ops |
|
|
|
>>> import numpy as np |
|
|
|
>>> |
|
|
|
>>> context.set_context(mode=context.GRAPH_MODE) |
|
|
|
>>> init() |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
... def __init__(self): |
|
|
|
@@ -278,6 +287,10 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
|
>>> net = Net() |
|
|
|
>>> output = net(input_) |
|
|
|
>>> print(output) |
|
|
|
[[2. 2. 2. 2. 2. 2. 2. 2.] |
|
|
|
[2. 2. 2. 2. 2. 2. 2. 2.] |
|
|
|
[2. 2. 2. 2. 2. 2. 2. 2.] |
|
|
|
[2. 2. 2. 2. 2. 2. 2. 2.]] |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
|