|
|
|
@@ -14,7 +14,7 @@ |
|
|
|
|
|
|
|
""" test Communicate """ |
|
|
|
import numpy as np |
|
|
|
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp |
|
|
|
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter |
|
|
|
from mindspore.ops.operations.comm_ops import Broadcast |
|
|
|
from mindspore.communication.management import HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, GlobalComm, init |
|
|
|
from mindspore.communication._comm_helper import Backend |
|
|
|
@@ -78,6 +78,19 @@ class AllGatherNet(nn.Cell): |
|
|
|
x = self.allgather(x) |
|
|
|
return self.relu(x) |
|
|
|
|
|
|
|
class ReduceScatterNet(nn.Cell): |
|
|
|
"""ReduceScatterNet definition""" |
|
|
|
def __init__(self, input_channel, out_channel, op): |
|
|
|
super(ReduceScatterNet, self).__init__() |
|
|
|
self.dense = Dense(input_channel, out_channel) |
|
|
|
self.reducescatter = ReduceScatter(op) |
|
|
|
self.relu = ReLU() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.dense(x) |
|
|
|
x = self.reducescatter(x) |
|
|
|
return self.relu(x) |
|
|
|
|
|
|
|
class AlltoAllNet(nn.Cell): |
|
|
|
"""AlltoAllNet definition""" |
|
|
|
def __init__(self, input_channel, out_channel): |
|
|
|
@@ -126,6 +139,25 @@ def test_allgather(): |
|
|
|
network = TrainOneStepCell(network, optimizer) |
|
|
|
_executor.compile(network, input_tensor, label_tensor) |
|
|
|
|
|
|
|
def run_reducescatter(op): |
|
|
|
"""run_reducescatter""" |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32)) |
|
|
|
label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32)) |
|
|
|
network = ReduceScatterNet(2, 1, op) |
|
|
|
loss_fn = nn.SoftmaxCrossEntropyWithLogits() |
|
|
|
optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), |
|
|
|
learning_rate=0.1, |
|
|
|
momentum=0.9) |
|
|
|
network = WithLossCell(network, loss_fn) |
|
|
|
network = TrainOneStepCell(network, optimizer) |
|
|
|
_executor.compile(network, input_tensor, label_tensor) |
|
|
|
|
|
|
|
def test_reducescatter(): |
|
|
|
"""test_reducescatter""" |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
run_reducescatter(ReduceOp.SUM) |
|
|
|
|
|
|
|
def test_broadcast(): |
|
|
|
"""test_broadcast""" |
|
|
|
context.set_context(mode=context.GRAPH_MODE) |
|
|
|
|