Merge pull request !2845 from zhoufeng/empty-graph-dump-irtags/v0.6.0-beta
| @@ -694,7 +694,7 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> & | |||||
| MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size(); | MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size(); | ||||
| } | } | ||||
| if (shapes.empty()) { | if (shapes.empty()) { | ||||
| MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes"; | |||||
| node->set_abstract(std::make_shared<abstract::AbstractNone>()); | |||||
| } else if (shapes.size() == 1) { | } else if (shapes.size() == 1) { | ||||
| // single output handle | // single output handle | ||||
| std::vector<int> shape_int; | std::vector<int> shape_int; | ||||
| @@ -1012,6 +1012,9 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN | |||||
| auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { | auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { | ||||
| auto partial = switch_node->input(input_index); | auto partial = switch_node->input(input_index); | ||||
| MS_EXCEPTION_IF_NULL(partial); | MS_EXCEPTION_IF_NULL(partial); | ||||
| if (IsValueNode<KernelGraph>(partial)) { | |||||
| return GetValueNode<KernelGraphPtr>(partial); | |||||
| } | |||||
| auto partial_cnode = partial->cast<CNodePtr>(); | auto partial_cnode = partial->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(partial_cnode); | MS_EXCEPTION_IF_NULL(partial_cnode); | ||||
| auto graph_node = partial_cnode->input(1); | auto graph_node = partial_cnode->input(1); | ||||
| @@ -411,8 +411,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||||
| origin_switch_inputs[kCNodeSwitchCond]}; | origin_switch_inputs[kCNodeSwitchCond]}; | ||||
| for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | ||||
| // 3.1 branch kernel graph and args | // 3.1 branch kernel graph and args | ||||
| KernelGraphPtr branch_fg; | |||||
| std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| // 3.2 recurse sub graph | // 3.2 recurse sub graph | ||||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | ||||
| new_switch_inputs.push_back(branch_label); | new_switch_inputs.push_back(branch_label); | ||||
| @@ -456,8 +455,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||||
| origin_switch_inputs[kCNodeSwitchCond]}; | origin_switch_inputs[kCNodeSwitchCond]}; | ||||
| for (size_t i = 0; i < branch_partial.size(); ++i) { | for (size_t i = 0; i < branch_partial.size(); ++i) { | ||||
| // 3.1 branch kernel graph and args | // 3.1 branch kernel graph and args | ||||
| KernelGraphPtr branch_fg; | |||||
| std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| // 3.2 recurse sub graph | // 3.2 recurse sub graph | ||||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | ||||
| new_switch_inputs.push_back(branch_label); | new_switch_inputs.push_back(branch_label); | ||||
| @@ -468,8 +466,11 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||||
| MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); | MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); | ||||
| } | } | ||||
| std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||||
| KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||||
| if (!node.get()->isa<CNode>()) { | if (!node.get()->isa<CNode>()) { | ||||
| if (IsValueNode<KernelGraph>(node)) { | |||||
| return GetValueNode<KernelGraphPtr>(node); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); | MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); | ||||
| } | } | ||||
| // 2.1 branch kernel graph and args | // 2.1 branch kernel graph and args | ||||
| @@ -484,7 +485,7 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A | |||||
| MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; | MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; | ||||
| } | } | ||||
| auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]); | auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]); | ||||
| return {partial_cnode, branch_kg}; | |||||
| return branch_kg; | |||||
| } | } | ||||
| void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, | void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, | ||||
| @@ -53,7 +53,7 @@ class AscendControlParser { | |||||
| static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | ||||
| const CNodePtr &last_label); | const CNodePtr &last_label); | ||||
| static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node); | |||||
| static KernelGraphPtr ParsePartial(NotNull<AnfNodePtr> node); | |||||
| static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph, | static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph, | ||||
| NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | ||||
| @@ -255,6 +255,9 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag, | |||||
| MS_EXCEPTION_IF_NULL(switch_cnode); | MS_EXCEPTION_IF_NULL(switch_cnode); | ||||
| auto partial = switch_cnode->input(input_index); | auto partial = switch_cnode->input(input_index); | ||||
| MS_EXCEPTION_IF_NULL(partial); | MS_EXCEPTION_IF_NULL(partial); | ||||
| if (IsValueNode<KernelGraph>(partial)) { | |||||
| return {}; | |||||
| } | |||||
| auto partial_cnode = partial->cast<CNodePtr>(); | auto partial_cnode = partial->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(partial_cnode); | MS_EXCEPTION_IF_NULL(partial_cnode); | ||||
| auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); | auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); | ||||
| @@ -387,18 +387,16 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { | |||||
| } else { | } else { | ||||
| kernel_info->SetFeatureMapFlag(true); | kernel_info->SetFeatureMapFlag(true); | ||||
| } | } | ||||
| // if output is a tuple tensor,now can use for loop to handle tuple tensor | |||||
| output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter); | |||||
| } | } | ||||
| new_parameter->set_kernel_info(kernel_info); | new_parameter->set_kernel_info(kernel_info); | ||||
| // create kernel_build_info for new parameter | // create kernel_build_info for new parameter | ||||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | ||||
| // create init data type, | // create init data type, | ||||
| std::vector<TypeId> init_data_type = {}; | std::vector<TypeId> init_data_type = {}; | ||||
| for (size_t i = 0; i < output_tensor_num; i++) { | |||||
| TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i); | |||||
| init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); | |||||
| } | |||||
| TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); | |||||
| init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); | |||||
| // set the format of parameter to DEFAULT_FORMAT | // set the format of parameter to DEFAULT_FORMAT | ||||
| kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT)); | kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT)); | ||||
| // set parameter initaial device data type | // set parameter initaial device data type | ||||
| @@ -590,7 +590,8 @@ TEST_F(AnfRuntimeAlgorithmTest, SetOutputInferTypeAndShape) { | |||||
| std::vector<TypeId> none_types = {}; | std::vector<TypeId> none_types = {}; | ||||
| std::vector<std::vector<size_t>> none_shapes = {}; | std::vector<std::vector<size_t>> none_shapes = {}; | ||||
| EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, nullptr), std::runtime_error); | EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, nullptr), std::runtime_error); | ||||
| EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get()), std::runtime_error); | |||||
| AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get()); | |||||
| EXPECT_EQ((*add->abstract()), abstract::AbstractNone()); | |||||
| // set single input | // set single input | ||||
| std::vector<TypeId> single_types = {kFloat32->type_id()}; | std::vector<TypeId> single_types = {kFloat32->type_id()}; | ||||
| std::vector<std::vector<size_t>> single_shapes = {{2, 32, 224, 224}}; | std::vector<std::vector<size_t>> single_shapes = {{2, 32, 224, 224}}; | ||||