From c575b467dbf7a3c084d13d3781a6f7472386d9c2 Mon Sep 17 00:00:00 2001 From: y00500818 Date: Thu, 19 Nov 2020 17:09:50 +0800 Subject: [PATCH] parser one to many --- parser/common/CMakeLists.txt | 1 + parser/common/module.mk | 1 + parser/common/parser_utils.cc | 202 +++++++++++++++++++++++++ parser/common/parser_utils.h | 37 +++++ parser/onnx/onnx_parser.cc | 3 + parser/tensorflow/tensorflow_parser.cc | 11 +- 6 files changed, 253 insertions(+), 2 deletions(-) create mode 100644 parser/common/parser_utils.cc create mode 100644 parser/common/parser_utils.h diff --git a/parser/common/CMakeLists.txt b/parser/common/CMakeLists.txt index f5e987b..91d8d80 100644 --- a/parser/common/CMakeLists.txt +++ b/parser/common/CMakeLists.txt @@ -24,6 +24,7 @@ set(SRC_LIST "pass_manager.cc" "parser_fp16_t.cc" "thread_pool.cc" + "parser_utils.cc" ) ############ libparser_common.so ############ diff --git a/parser/common/module.mk b/parser/common/module.mk index ae8d9a2..ec2a1ac 100644 --- a/parser/common/module.mk +++ b/parser/common/module.mk @@ -36,6 +36,7 @@ COMMON_LOCAL_SRC_FILES := \ pass_manager.cc \ parser_fp16_t.cc \ thread_pool.cc \ + parser_utils.cc \ FMK_COMMON_SRC_FILES := \ # ../../common/fmk_error_codes.cc \ diff --git a/parser/common/parser_utils.cc b/parser/common/parser_utils.cc new file mode 100644 index 0000000..3353022 --- /dev/null +++ b/parser/common/parser_utils.cc @@ -0,0 +1,202 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "parser_utils.h" +#include "external/ge/ge_api_types.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "framework/omg/parser/parser_types.h" +#include "graph/anchor.h" +#include "graph/compute_graph.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_adapter.h" +#include "graph/utils/op_desc_utils.h" +#include "register/op_registry.h" + +namespace ge { +Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { + GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); + for (const auto &gn : graph.GetDirectNode()) { + NodePtr n = NodeAdapter::GNode2Node(gn); + GE_CHECK_NOTNULL(n); + std::string ori_type; + (void)AttrUtils::GetStr(n->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type); + domi::ParseOpToGraphFunc parse_op_to_graph_func = + domi::OpRegistry::Instance()->GetParseOpToGraphFunc(n->GetType(), ori_type); + if (parse_op_to_graph_func == nullptr) { + GELOGD("node:%s type:%s ori type:%s has no parse_op_to_graph_func.", + n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str()); + continue; + } + GELOGI("node:%s type:%s ori type:%s has registered one to many parser func.", + n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str()); + Graph subgraph("one_to_many_graph"); + Operator op = OpDescUtils::CreateOperatorFromNode(n); + Status ret = parse_op_to_graph_func(op, subgraph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Get one to many graph failed for op:%s.", op.GetName().c_str()); + return FAILED; + } + ret = ExpandNodeToSubgraph(subgraph, n, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Expand one to many graph failed for op:%s.", op.GetName().c_str()); + return FAILED; + } + } + GELOGD("run ParserUtils::ExpandOneToManyGraph success."); + return SUCCESS; +} + +Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph) { + ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph); + GE_CHECK_NOTNULL(sub_compute_graph); + ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + // add subgraph node to graph. + std::unordered_map all_new_nodes; + std::vector input_nodes; + for (const auto &n : sub_compute_graph->GetDirectNode()) { + auto new_node = compute_graph->AddNode(n); + GE_CHECK_NOTNULL(new_node); + all_new_nodes[new_node->GetName()] = new_node; + if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Set owner graph for node:%s failed.", new_node->GetName().c_str()); + return FAILED; + } + + if (new_node->GetType() == ge::parser::DATA) { + input_nodes.emplace_back(new_node); + } + } + + // handle input context. + Status ret = HandleInputContext(node, input_nodes, compute_graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "run ParserUtils::HandleInputContext failed."); + return FAILED; + } + + // handle output context. + std::vector> out_node_index = sub_compute_graph->GetGraphOutNodesInfo(); + ret = HandleOutputContext(node, out_node_index); + if (ret != SUCCESS) { + GELOGE(FAILED, "run ParserUtils::HandleOutputContext failed."); + return FAILED; + } + + graphStatus graph_status = GraphUtils::RemoveNodeWithoutRelink(compute_graph, node); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove node:%s failed.", node->GetName().c_str()); + return FAILED; + } + graph_status = compute_graph->TopologicalSorting(); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Topological sorting failed."); + return FAILED; + } + return SUCCESS; +} + +Status ParserUtils::HandleInputContext(const NodePtr &node, + const std::vector &input_nodes, + const ComputeGraphPtr &compute_graph) { + GE_CHECK_NOTNULL(node); + for (const auto &in_n : input_nodes) { + GE_CHECK_NOTNULL(in_n); + int index; + if (!AttrUtils::GetInt(in_n->GetOpDesc(), ATTR_NAME_INDEX, index)) { + GELOGE(FAILED, "Get attr index of node:%s failed.", in_n->GetName().c_str()); + return FAILED; + } + GELOGD("Begin to handle input node:%s with index:%d.", in_n->GetName().c_str(), index); + // get node's in data anchor and peer out anchor + auto node_in_anchor = node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(node_in_anchor); + auto src_out_anchor = node_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(src_out_anchor); + auto data_out_anchor = in_n->GetOutDataAnchor(0); + GE_CHECK_NOTNULL(data_out_anchor); + for (const auto &peer_in_anchor : data_out_anchor->GetPeerInDataAnchors()) { + // add data edge + graphStatus ret = GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "remove data out anchor and peer in anchor failed."); + return FAILED; + } + ret = GraphUtils::RemoveEdge(src_out_anchor, node_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "remove node in anchor and peer out anchor failed."); + return FAILED; + } + ret = GraphUtils::AddEdge(src_out_anchor, peer_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "link node's peer out anchor and data's peer in anchor failed."); + return FAILED; + } + + // add control edge + if (node->GetInControlAnchor() != nullptr) { + for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) { + graphStatus ret = GraphUtils::AddEdge(out_anchor, peer_in_anchor->GetOwnerNode()->GetInControlAnchor()); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "add control edge failed."); + return FAILED; + } + } + } + } + graphStatus ret = GraphUtils::RemoveNodeWithoutRelink(compute_graph, in_n); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "remove node:%s failed.", in_n->GetName().c_str()); + return FAILED; + } + } + return SUCCESS; +} + +Status ParserUtils::HandleOutputContext(const NodePtr &node, + const std::vector> &out_node_index) { + GE_CHECK_NOTNULL(node); + GELOGD("The size of out node is %zu", out_node_index.size()); + for (size_t index = 0; index < out_node_index.size(); index++) { + auto node_out_anchor = node->GetOutDataAnchor(index); + if (node_out_anchor == nullptr) { + continue; + } + + NodePtr out_node = out_node_index[index].first; + int32_t out_index = out_node_index[index].second; + GELOGD("Begin to handle output node:%s[%zu] with index:%zu", out_node->GetName().c_str(), out_index, index); + auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor. + GE_CHECK_NOTNULL(src_out_anchor); + for (const auto &dest_in_anchor : node_out_anchor->GetPeerInDataAnchors()) { + graphStatus ret = GraphUtils::RemoveEdge(node_out_anchor, dest_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "remove node's out anchor and peer in anchor failed."); + return FAILED; + } + ret = GraphUtils::AddEdge(src_out_anchor, dest_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "link node's peer out anchor and out node's out anchor failed."); + return FAILED; + } + } + } + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/parser/common/parser_utils.h b/parser/common/parser_utils.h new file mode 100644 index 0000000..1be6d70 --- /dev/null +++ b/parser/common/parser_utils.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef PARSER_COMMON_PARSER_UTILS_H_ +#define PARSER_COMMON_PARSER_UTILS_H_ + +#include "graph/graph.h" +#include "graph/node.h" +#include "external/ge/ge_api_error_codes.h" + +namespace ge { +class ParserUtils { + public: + static Status ExpandOneToManyGraph(Graph &graph); + + private: + static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph); + static Status HandleInputContext(const NodePtr &node, + const std::vector &input_nodes, + const ComputeGraphPtr &compute_graph); + static Status HandleOutputContext(const NodePtr &node, const std::vector> &out_node_index); +}; +} // namespace ge +#endif // PARSER_COMMON_PARSER_UTILS_H_ diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 25e6f8c..536ec5c 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -30,6 +30,7 @@ #include "parser/common/pre_checker.h" #include "parser/common/acl_graph_parser_util.h" #include "parser/common/model_saver.h" +#include "parser/common/parser_utils.h" #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" @@ -535,6 +536,8 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { } graph.SetInputs(input_ops).SetOutputs(output_indexs); + GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); + UpdateFormat(graph); GELOGI("Onnx model parser success."); diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 7844bfd..bce0002 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -45,6 +45,7 @@ #include "parser/common/pass_manager.h" #include "parser/common/pre_checker.h" #include "parser/common/thread_pool.h" +#include "parser/common/parser_utils.h" #include "parser/tensorflow/tensorflow_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" #include "parser/tensorflow/tensorflow_fusion_op_parser.h" @@ -201,8 +202,8 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque auto i = subgraph_name_to_index.second; auto subgraph_iname = op_desc->GetSubgraphInstanceName(i); if (subgraph_iname.empty()) { - GELOGE(PARAM_INVALID, "The subgraph index %u of node %s is empty", i, node->GetName().c_str()); - return PARAM_INVALID; + GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str()); + continue; } // A function may be referenced multiple times in TF, change the graph name to ensure it is unique in GE @@ -1414,6 +1415,8 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro DeleteFuisonNodeDef(); GE_RETURN_IF_ERROR(AddEdges(graph)); + Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph); + GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph)); GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph)); GE_RETURN_IF_ERROR(graph->TopologicalSorting()); @@ -2208,6 +2211,10 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, GELOGD("[TF Parser] Add framework node success"); ret = AddEdges(graph); + + Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph); + GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph)); + DeleteFuisonNodeDef(); GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed"); GELOGD("[TF Parser] Add edges success");