diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 6675bbba03..0f965cdbb4 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -183,11 +183,11 @@ def get_bprop_allswap(self):
     all_swap_grad = AllSwap(self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
-        all_to_all_grad.set_prim_instance_name(instance_name)
+        all_swap_grad.set_prim_instance_name(instance_name)
 
     def bprop(x, send_size, recv_size, out, dout):
         dx = all_swap_grad(dout, recv_size, send_size)
-        return (dx,)
+        return (dx, zeros_like(send_size), zeros_like(recv_size))
 
     return bprop
 
diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py
index 618fa23ac3..d32907a2aa 100644
--- a/tests/ut/python/communication/test_comm.py
+++ b/tests/ut/python/communication/test_comm.py
@@ -27,7 +27,7 @@ from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
-from mindspore.ops.operations.math_ops import ReduceSum
+from mindspore.ops.operations.array_ops import GatherV2
 import mindspore
 
 # pylint: disable=W0212
@@ -127,14 +127,15 @@ class AllSwapNet(nn.Cell):
         self.dense = Dense(input_channel, out_channel)
         self.allswap = AllSwap()
         self.relu = ReLU()
-        self.reduce = ReduceSum()
         part_slice = batch_size / 2
         self.send_size = Tensor([0, part_slice*out_channel, part_slice*out_channel], mindspore.int64)
         self.recv_size = Tensor([part_slice*out_channel, part_slice*out_channel, 0], mindspore.int64)
+        self.gatherv2 = GatherV2()
+        self.input = Tensor(np.ones([1]), mindspore.int32)
 
     def construct(self, x):
-        x = self.dense(x)
         x = self.allswap(x, self.send_size, self.recv_size)
         x = self.relu(x)
+        x = self.gatherv2(x, self.input, 0)
         return x
 
@@ -180,8 +181,15 @@ def test_allswap():
     """run_allswap"""
     context.set_context(mode=context.GRAPH_MODE)
     input_tensor = Tensor(np.ones((100, 20)), dtype=mindspore.float32)
+    label_tensor = Tensor(np.ones((1, 20)), dtype=mindspore.float32)
     network = AllSwapNet(100, 20, 20)
-    _executor.compile(network, input_tensor)
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
+                         learning_rate=0.1,
+                         momentum=0.9)
+    network = WithLossCell(network, loss_fn)
+    network = TrainOneStepCell(network, optimizer)
+    _executor.compile(network, input_tensor, label_tensor)
 
 
 def run_reducescatter(op):
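
Note on the bprop change (sketch, not part of the patch): a registered MindSpore
bprop must return one gradient per forward input, and AllSwap's forward takes
three inputs (x, send_size, recv_size), so the old 1-tuple return was short two
entries; the reference to the undefined name all_to_all_grad looks like a
leftover from the AlltoAll bprop. Below is a minimal self-contained sketch of
the corrected getter. The import paths are my assumption about how the _grad
modules are laid out; they are not shown in this diff.

    # Sketch only: the import paths below are assumed, not taken from this patch.
    from mindspore.ops._grad.grad_base import bprop_getters
    from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
    from mindspore.ops.operations.comm_ops import AllSwap

    @bprop_getters.register(AllSwap)
    def get_bprop_allswap(self):
        """Generate bprop for AllSwap."""
        # Reuse the AllSwap primitive for the backward pass in the same group.
        all_swap_grad = AllSwap(self.group)
        if self.instance_name:
            instance_name = "grad" + self.instance_name
            all_swap_grad.set_prim_instance_name(instance_name)

        def bprop(x, send_size, recv_size, out, dout):
            # The backward swap moves data in the opposite direction, so the
            # send and receive size vectors trade places.
            dx = all_swap_grad(dout, recv_size, send_size)
            # One gradient per forward input: the integer size tensors are
            # non-differentiable, so they get zeros_like placeholders.
            return (dx, zeros_like(send_size), zeros_like(recv_size))

        return bprop

The test change exercises this path: wrapping AllSwapNet in WithLossCell and
TrainOneStepCell makes _executor.compile build the backward graph, which the
old forward-only compile never did, so the missing placeholder gradients now
surface at compile time.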