diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 5f896282dc..81ad02e787 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -694,7 +694,7 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector & MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size(); } if (shapes.empty()) { - MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes"; + node->set_abstract(std::make_shared()); } else if (shapes.size() == 1) { // single output handle std::vector shape_int; @@ -1012,6 +1012,9 @@ std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { auto partial = switch_node->input(input_index); MS_EXCEPTION_IF_NULL(partial); + if (IsValueNode(partial)) { + return GetValueNode(partial); + } auto partial_cnode = partial->cast(); MS_EXCEPTION_IF_NULL(partial_cnode); auto graph_node = partial_cnode->input(1); diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 18e71d74e3..0c97116c6e 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -411,8 +411,7 @@ void AscendControlParser::RecurseSwitch(NotNull kg, NotNull kg, NotNull origin_switch_inputs[kCNodeSwitchCond]}; for (size_t i = 0; i < branch_partial.size(); ++i) { // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg; - std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); // 3.2 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); @@ -468,8 +466,11 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); } -std::tuple AscendControlParser::ParsePartial(NotNull node) { +KernelGraphPtr AscendControlParser::ParsePartial(NotNull node) { if (!node.get()->isa()) { + if (IsValueNode(node)) { + return GetValueNode(node); + } MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); } // 2.1 branch kernel graph and args @@ -484,7 +485,7 @@ std::tuple AscendControlParser::ParsePartial(NotNull(partial_inputs[kCNodePartialFunc]); - return {partial_cnode, branch_kg}; + return branch_kg; } void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index 82479fa527..7530f2019e 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -53,7 +53,7 @@ class AscendControlParser { static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, const CNodePtr &last_label); - static std::tuple ParsePartial(NotNull node); + static KernelGraphPtr ParsePartial(NotNull node); static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, NotNull from, NotNull to); diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index c6ef41f2ba..f361cb26ca 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -255,6 +255,9 @@ static void UpdateRealInput(NotNull graph, bool split_flag, MS_EXCEPTION_IF_NULL(switch_cnode); auto partial = switch_cnode->input(input_index); MS_EXCEPTION_IF_NULL(partial); + if (IsValueNode(partial)) { + return {}; + } auto partial_cnode = partial->cast(); MS_EXCEPTION_IF_NULL(partial_cnode); auto ret = std::vector(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index f79363fd03..264e2c661b 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -387,18 +387,16 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { } else { kernel_info->SetFeatureMapFlag(true); } - // if output is a tuple tensor,now can use for loop to handle tuple tensor - output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter); } new_parameter->set_kernel_info(kernel_info); // create kernel_build_info for new parameter auto kernel_build_info_builder = std::make_shared(); // create init data type, std::vector init_data_type = {}; - for (size_t i = 0; i < output_tensor_num; i++) { - TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i); - init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); - } + + TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); + init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); + // set the format of parameter to DEFAULT_FORMAT kernel_build_info_builder->SetOutputsFormat(std::vector(output_tensor_num, kOpFormat_DEFAULT)); // set parameter initaial device data type diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc index 2ea2453381..4c94cdde57 100644 --- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc +++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc @@ -590,7 +590,8 @@ TEST_F(AnfRuntimeAlgorithmTest, SetOutputInferTypeAndShape) { std::vector none_types = {}; std::vector> none_shapes = {}; EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, nullptr), std::runtime_error); - EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get()), std::runtime_error); + AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get()); + EXPECT_EQ((*add->abstract()), abstract::AbstractNone()); // set single input std::vector single_types = {kFloat32->type_id()}; std::vector> single_shapes = {{2, 32, 224, 224}};