|
|
|
@@ -66,11 +66,12 @@ class AllReduce(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore.communication.management import init |
|
|
|
>>> import mindspore.ops.operations as P |
|
|
|
>>> init('nccl') |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> super(Net, self).__init__() |
|
|
|
>>> self.allreduce_sum = AllReduce(ReduceOp.SUM, group="nccl_world_group") |
|
|
|
>>> self.allreduce_sum = P.AllReduce(ReduceOp.SUM, group="nccl_world_group") |
|
|
|
>>> |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> return self.allreduce_sum(x) |
|
|
|
@@ -130,11 +131,12 @@ class AllGather(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore.communication.management import init |
|
|
|
>>> import mindspore.ops.operations as P |
|
|
|
>>> init('nccl') |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> super(Net, self).__init__() |
|
|
|
>>> self.allgather = AllGather(group="nccl_world_group") |
|
|
|
>>> self.allgather = P.AllGather(group="nccl_world_group") |
|
|
|
>>> |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> return self.allgather(x) |
|
|
|
@@ -184,11 +186,12 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore.communication.management import init |
|
|
|
>>> import mindspore.ops.operations as P |
|
|
|
>>> init('nccl') |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> super(Net, self).__init__() |
|
|
|
>>> self.reducescatter = ReduceScatter(ReduceOp.SUM, group="nccl_world_group") |
|
|
|
>>> self.reducescatter = P.ReduceScatter(ReduceOp.SUM, group="nccl_world_group") |
|
|
|
>>> |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> return self.reducescatter(x) |
|
|
|
@@ -246,11 +249,12 @@ class Broadcast(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore.communication.management import init |
|
|
|
>>> import mindspore.ops.operations as P |
|
|
|
>>> init('nccl') |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
>>> def __init__(self): |
|
|
|
>>> super(Net, self).__init__() |
|
|
|
>>> self.broadcast = Broadcast(1) |
|
|
|
>>> self.broadcast = P.Broadcast(1) |
|
|
|
>>> |
|
|
|
>>> def construct(self, x): |
|
|
|
>>> return self.broadcast((x,)) |
|
|
|
|