From 093ef784de9f5f9004c200b980f8c52b618be7bc Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Mon, 26 Apr 2021 08:06:42 +0800
Subject: [PATCH] dont insert virtualoutput for scalar

---
Summary of the change:
  * step_parallel.cc, InsertVirtualOutput: pre_node and shape_outputs are now
    computed before the VIRTUAL_OUTPUT operator is built, and the current
    iteration is skipped (continue) when shape_outputs[0] is empty, so no
    VirtualOutput node is inserted for that output.
  * test_virtual_output.py: adds ParallelMulNet (Flatten -> Dense -> Mul),
    compile_graph_two_input, and two tests (semi_auto_parallel and
    auto_parallel) that wrap the net in nn.WithEvalCell and assert that exactly
    one 'VirtualOutput-op' strategy exists and that its v[0][0] equals 8.

Reviewer note: this copy of the patch was rebuilt from a flattened text in
which line breaks, indentation, the author e-mail and the C++ template
arguments (<AnfNodePtr>, <CNodePtr>) had been lost. Hunk line counts were
re-checked, but whitespace is inferred from the projects' usual 2-space (C++)
and 4-space (Python) style -- TODO confirm against the upstream commit before
applying (or apply with --ignore-whitespace).

 .../ccsrc/frontend/parallel/step_parallel.cc  |  7 ++-
 .../ut/python/parallel/test_virtual_output.py | 55 +++++++++++++++++++
 2 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 3548f3623d..bcc68ad76c 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1040,14 +1040,17 @@ void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
       cnode = node_pair.first->cast<CNodePtr>();
       last_indexs[last_node_index] = size_t(node_pair.second);
     }
+    auto pre_node = cnode->input(last_indexs[last_node_index]);
+    Shapes shape_outputs = GetNodeShape(pre_node);
+    if (shape_outputs[0].empty()) {
+      continue;
+    }
     FuncGraphPtr func_graph = node->func_graph();
     MS_EXCEPTION_IF_NULL(func_graph);
     OperatorParams params;
     OperatorAttrs attrs;
     OperatorArgs args = std::make_pair(attrs, params);
     Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
-    auto pre_node = cnode->input(last_indexs[last_node_index]);
-    Shapes shape_outputs = GetNodeShape(pre_node);
     InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
     auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
     AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
diff --git a/tests/ut/python/parallel/test_virtual_output.py b/tests/ut/python/parallel/test_virtual_output.py
index 84a328d47a..e44a769985 100644
--- a/tests/ut/python/parallel/test_virtual_output.py
+++ b/tests/ut/python/parallel/test_virtual_output.py
@@ -97,6 +97,24 @@ class ReshapeMulNet(nn.Cell):
         out = self.mul(weight, self.mul_weight)
         return out
 
+class ParallelMulNet(nn.Cell):
+    def __init__(self, dense_in_channel=2048, dense_out_channel=250):
+        super().__init__()
+        weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
+        bias_np = np.full((dense_out_channel,), 0.01, dtype=np.float32)
+        self.flat = nn.Flatten()
+        self.dense = nn.Dense(in_channels=dense_in_channel,
+                              out_channels=dense_out_channel,
+                              weight_init=Tensor(weight_np),
+                              bias_init=Tensor(bias_np),
+                              has_bias=True)
+        self.mul = P.Mul()
+    def construct(self, inputs):
+        x = self.flat(inputs)
+        x = self.dense(x)
+        x = self.mul(x, x)
+        return x
+
 def compile_graph(x, net):
     net.set_auto_parallel()
     net.set_train(False)
@@ -104,6 +122,13 @@ def compile_graph(x, net):
     strategies = _executor._get_shard_strategy(net)
     return strategies
 
+def compile_graph_two_input(x, y, net):
+    net.set_auto_parallel()
+    net.set_train(False)
+    _executor.compile(net, x, y, auto_parallel_mode=True)
+    strategies = _executor._get_shard_strategy(net)
+    return strategies
+
 
 def test_dense_relu_semi_auto():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
@@ -250,3 +275,33 @@ def test_reshape_mul_auto():
     for (k, v) in strategies.items():
         if re.search('VirtualOutput-op', k) is not None:
             assert v[0][0] == 1
+
+def test_scalar_output_semi_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
+    net = ParallelMulNet()
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
+    eval_net = nn.WithEvalCell(net, loss_fn)
+    x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32)*0.01)
+    label = Tensor(np.ones([4096, 250]).astype(np.float32)*0.01)
+    strategies = compile_graph_two_input(x, label, eval_net)
+    count = 0
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+            count += 1
+    assert count == 1
+
+def test_scalar_output_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
+    net = ParallelMulNet()
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
+    eval_net = nn.WithEvalCell(net, loss_fn)
+    x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32)*0.01)
+    label = Tensor(np.ones([4096, 250]).astype(np.float32)*0.01)
+    strategies = compile_graph_two_input(x, label, eval_net)
+    count = 0
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+            count += 1
+    assert count == 1