From b87af40416550d13a9955d94e686a3bba6da133f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=8D=8E?= Date: Mon, 22 Mar 2021 14:03:41 +0800 Subject: [PATCH] onnx_if_loop --- .gitmodules | 4 - metadef | 1 - parser/common/parser_types.cc | 3 + parser/onnx/CMakeLists.txt | 4 +- parser/onnx/onnx_data_parser.cc | 14 +- parser/onnx/onnx_data_parser.h | 6 + parser/onnx/onnx_if_subgraph_adapter.cc | 125 ++++++++ parser/onnx/onnx_if_subgraph_adapter.h | 47 +++ parser/onnx/onnx_parser.cc | 379 +++++++++++++++++++++--- parser/onnx/onnx_parser.h | 24 +- parser/onnx/onnx_retval_parser.cc | 48 +++ parser/onnx/onnx_retval_parser.h | 29 ++ parser/onnx/onnx_subgraph_adapter.h | 56 ++++ parser/onnx/onnx_util.h | 3 + 14 files changed, 695 insertions(+), 48 deletions(-) delete mode 160000 metadef create mode 100644 parser/onnx/onnx_if_subgraph_adapter.cc create mode 100644 parser/onnx/onnx_if_subgraph_adapter.h create mode 100644 parser/onnx/onnx_retval_parser.cc create mode 100644 parser/onnx/onnx_retval_parser.h create mode 100644 parser/onnx/onnx_subgraph_adapter.h diff --git a/.gitmodules b/.gitmodules index 4b23427..e69de29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +0,0 @@ -[submodule "metadef"] - path = metadef - url = https://gitee.com/ascend/metadef.git - branch = master diff --git a/metadef b/metadef deleted file mode 160000 index 86781b7..0000000 --- a/metadef +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 86781b7e8ce21d2b901406cc3619d6bea2aeb18e diff --git a/parser/common/parser_types.cc b/parser/common/parser_types.cc index b53d37a..b4c0f57 100644 --- a/parser/common/parser_types.cc +++ b/parser/common/parser_types.cc @@ -438,6 +438,9 @@ const char *HVDCALLBACKALLGATHER = "HorovodAllgather"; const char *HVDCALLBACKBROADCAST = "HorovodBroadcast"; const char *HVDWAIT = "HorovodWait"; +// onnx output operator +const char *_RETVAL = "_Retval"; + /// /// @brief Magic number of model file /// diff --git a/parser/onnx/CMakeLists.txt b/parser/onnx/CMakeLists.txt index 77cdcf1..20f9bfa 100644 --- a/parser/onnx/CMakeLists.txt +++ b/parser/onnx/CMakeLists.txt @@ -8,7 +8,9 @@ set(SRC_LIST "onnx_parser.cc" "onnx_data_parser.cc" "onnx_util.cc" - "onnx_constant_parser.cc" + "onnx_constant_parser.cc" + "onnx_retval_parser.cc" + "onnx_if_subgraph_adapter.cc" ) protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) diff --git a/parser/onnx/onnx_data_parser.cc b/parser/onnx/onnx_data_parser.cc index 29f966a..9305e3f 100644 --- a/parser/onnx/onnx_data_parser.cc +++ b/parser/onnx/onnx_data_parser.cc @@ -35,6 +35,11 @@ Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) GELOGE(FAILED, "parse shape of data op %s from model failed", op_def.GetName().c_str()); return FAILED; } + // Subgraph data operator don't need parse input shape + // the shape mappings from parent node input + if (IsSubgraphOp()) { + return SUCCESS; + } if (ParseInputFromUser(op_def) != SUCCESS) { GELOGE(FAILED, "parse shape of data op %s from user failed", op_def.GetName().c_str()); @@ -72,15 +77,23 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator & // Get attr t:'input_tensor' form NodeProto int64_t data_type = 1; int64_t index = 0; + isSubgraphOp_ = false; for (auto it : node->attribute()) { if (it.name() == ge::kAttrNameInput) { data_type = ParseInputTensor(it); } else if (it.name() == ge::kAttrNameIndex) { index = it.i(); GELOGI("The node has attribute with index: %ld", index); + } else if (it.name() == ge::kAttrNameIsSubgraphOp) { + isSubgraphOp_ = true; } } + op_def.SetAttr(ge::ATTR_NAME_INDEX, index); + if (IsSubgraphOp()) { + return SUCCESS; + } + // Trans onnx type to ge type DataType type = OnnxUtil::ConvertOnnxDataType(data_type); if (type == ge::DataType::DT_UNDEFINED) { @@ -88,7 +101,6 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator & return FAILED; } op_def.SetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, static_cast(type)); - op_def.SetAttr(ge::ATTR_NAME_INDEX, index); return SUCCESS; } diff --git a/parser/onnx/onnx_data_parser.h b/parser/onnx/onnx_data_parser.h index 9650af6..373a085 100644 --- a/parser/onnx/onnx_data_parser.h +++ b/parser/onnx/onnx_data_parser.h @@ -32,11 +32,17 @@ class PARSER_FUNC_VISIBILITY OnnxDataParser : public OnnxOpParser { Status ParseInputFromUser(const ge::Operator &op_def); + bool IsSubgraphOp() { + return isSubgraphOp_; + } + int64_t ParseInputTensor(const ge::onnx::AttributeProto &attribute); std::vector model_input_dims_v_; std::vector user_input_dims_v_; + + bool isSubgraphOp_; }; } // namespace ge diff --git a/parser/onnx/onnx_if_subgraph_adapter.cc b/parser/onnx/onnx_if_subgraph_adapter.cc new file mode 100644 index 0000000..c4135d8 --- /dev/null +++ b/parser/onnx/onnx_if_subgraph_adapter.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/onnx/onnx_if_subgraph_adapter.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "onnx_parser.h" + +using domi::ONNX; + +namespace ge{ +namespace { +const std::map kAttrNames = {{"then_branch", 0}, {"else_branch", 1}}; +const int kIfNodeMaxAttrSize = 2; +} +Status OnnxIfSubgraphAdapter::AdaptAndFindAllOnnxSubgraphs(ge::onnx::NodeProto *parent_node, + std::deque &onnx_graph_tasks, + std::map &name_to_onnx_graph) { + GE_CHECK_NOTNULL(parent_node); + GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), + parent_node->op_type().c_str()); + + auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graph_tasks, name_to_onnx_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Parse if node failed."); + return ret; + } + return SUCCESS; +} + +Status OnnxIfSubgraphAdapter::ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, + std::deque &onnx_graph_tasks, + std::map &name_to_onnx_graph) { + if (parent_node->attribute_size() != kIfNodeMaxAttrSize) { + GELOGE(FAILED, "invalid graph, node->attribute_size():%d.", parent_node->attribute_size()); + return FAILED; + } + GELOGI("node attribute size:%d.", parent_node->attribute_size()); + std::set all_inputs; + std::vector onnx_graphs; + for (int i = 0; i < parent_node->attribute_size(); i++) { + ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i); + std::string attr_name = attribute->name(); + auto itr = kAttrNames.find(attr_name); + if (itr == kAttrNames.end()) { + GELOGE(FAILED, "invalid attribute name:%s.", attr_name.c_str()); + return FAILED; + } + ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); + + std::string sub_graph_name = parent_node->name() + "_" + std::to_string(itr->second) + "_" + itr->first; + GELOGI("Start parse if attribute:%s, subgraph name:%s.", attr_name.c_str(), sub_graph_name.c_str()); + name_to_onnx_graph[sub_graph_name] = onnx_graph; + onnx_graph_tasks.push_back(onnx_graph); + + auto ret = GetSubgraphsAllInputs(parent_node, onnx_graph, all_inputs); + if (ret != SUCCESS) { + GELOGE(ret, "get subgraps all inputs failed, attr_name:%s.", attr_name.c_str()); + return ret; + } + onnx_graphs.emplace_back(onnx_graph); + } + for (auto onnx_graph : onnx_graphs) { + AddInputNodeForGraph(all_inputs, onnx_graph); + } + AddInputForParentNode(all_inputs, parent_node); + return SUCCESS; +} + +Status OnnxIfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::NodeProto *parent_node, ge::onnx::GraphProto *sub_graph, + std::set &all_inputs) { + std::set graph_inputs; + std::set graph_outputs; + for (int i = 0; i < sub_graph->node_size(); i++) { + ge::onnx::NodeProto *node_proto = sub_graph->mutable_node(i); + std::string node_name = node_proto->name(); + + for (int j = 0; j < node_proto->input_size(); j++) { + graph_inputs.emplace(node_proto->input(j)); + } + + for (int j = 0; j < node_proto->output_size(); j++) { + graph_outputs.emplace(node_proto->output(j)); + } + } + + for (auto input : graph_inputs) { + auto out_iter = graph_outputs.find(input); + if (out_iter == graph_outputs.end()) { + // Record input node needed to be construct + all_inputs.emplace(input); + } + } + + return SUCCESS; +} + +void OnnxIfSubgraphAdapter::AddInputNodeForGraph(const std::set &all_inputs, + ge::onnx::GraphProto *onnx_graph) { + for (auto input_name : all_inputs) { + ge::onnx::ValueInfoProto *value_info = onnx_graph->add_input(); + value_info->set_name(input_name); + } +} + +void OnnxIfSubgraphAdapter::AddInputForParentNode(const std::set &all_inputs, + ge::onnx::NodeProto *parent_node) { + for (auto input_name : all_inputs) { + parent_node->add_input(input_name); + } +} +} // namespace ge diff --git a/parser/onnx/onnx_if_subgraph_adapter.h b/parser/onnx/onnx_if_subgraph_adapter.h new file mode 100644 index 0000000..37b855e --- /dev/null +++ b/parser/onnx/onnx_if_subgraph_adapter.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_ONNX_ONNX_IF_SUBGRAPH_ADAPTER_H_ +#define GE_PARSER_ONNX_ONNX_IF_SUBGRAPH_ADAPTER_H_ + +#include +#include +#include "onnx_subgraph_adapter.h" +#include "proto/onnx/ge_onnx.pb.h" + +using ge::onnx::NodeProto; + +namespace ge { +class PARSER_FUNC_VISIBILITY OnnxIfSubgraphAdapter : public OnnxSubgraphAdapter { + public: + /// @brief parse params + /// @param [in] parent_op parent op + /// @return SUCCESS parse success + /// @return FAILED Parse failed + Status AdaptAndFindAllOnnxSubgraphs(ge::onnx::NodeProto *parent_op, + std::deque &onnx_graph_tasks, + std::map &name_to_onnx_graph) override; +private: + Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::deque &onnx_graph_tasks, + std::map &name_to_onnx_graph); + Status GetSubgraphsAllInputs(ge::onnx::NodeProto *parent_node, ge::onnx::GraphProto *sub_graph, + std::set &all_inputs); + void AddInputNodeForGraph(const std::set &all_inputs, ge::onnx::GraphProto *onnx_graph); + void AddInputForParentNode(const std::set &all_inputs, ge::onnx::NodeProto *parent_node); +}; +} // namespace ge + +#endif // GE_PARSER_ONNX_ONNX_IF_SUBGRAPH_ADAPTER_H_ diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 8745fb9..12dc5bd 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -37,6 +37,10 @@ #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" #include "register/register_fmk_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "onnx_if_subgraph_adapter.h" namespace ge { graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, @@ -95,7 +99,7 @@ graphStatus aclgrphParseONNX(const char *model_file, } GE_CHECK_NOTNULL(model_parser); - // parse caffe model_file to GE graph + // parse onnx model_file to GE graph ge::graphStatus ret = model_parser->Parse(model_file, graph); if (ret != ge::SUCCESS) { GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); @@ -144,25 +148,187 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, namespace ge { namespace { const std::map kOnnxOpMap = { - {ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, + {ge::kOpTypeInput, ge::parser::DATA}, + {ge::kOpTypeConstant, ge::parser::CONSTANT}, + {ge::kOpTypeOutput, ge::parser::_RETVAL}, }; -const char* const MATMULV2 = "MatMulV2"; + +const std::vector kMakeOperatorNotByIr = {ge::parser::_RETVAL}; +const char *const MATMULV2 = "MatMulV2"; const std::vector kNoNeedUpdateFormat = {MATMULV2}; const int64_t kDimValue = 1; + +struct ParseArg { + ge::onnx::GraphProto *onnx_graph; + ge::NodePtr parent_node; + std::string graph_name; + uint32_t subgraph_index; +}; + +Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque &args) { + GELOGI("Gen subgraph parse tasks start"); + for (auto &node : parent_graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { + auto i = subgraph_name_to_index.second; + auto subgraph_iname = subgraph_name_to_index.first; + if (subgraph_iname.empty()) { + GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str()); + continue; + } + + // A function may be referenced multiple times in TF, change the graph name to ensure it is unique in GE + auto unique_name = node->GetName() + "_" + std::to_string(i) + "_" + subgraph_iname; + + GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s", + node->GetName().c_str(), i, unique_name.c_str()); + args.push_back({nullptr, node, unique_name, i}); + } + } + GELOGI("Gen subgraph parse tasks end"); + return SUCCESS; +} + +Status BuildLinkForSonAndParentGraph(const ge::ComputeGraphPtr &sub_graph, const ParseArg &arg) { + if (arg.parent_node == nullptr) { + return SUCCESS; + } + auto parent_node = arg.parent_node; + auto index = arg.subgraph_index; + auto ret = ge::NodeUtils::SetSubgraph(*parent_node, index, sub_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(), + parent_node->GetName().c_str(), index); + return ret; + } + return SUCCESS; +} + +Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_graph) { + if (arg.parent_node == nullptr) { + return SUCCESS; + } + std::string op_type = arg.parent_node->GetType(); + std::string op_name = arg.parent_node->GetName(); + domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; + auto post_func = + domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); + if (post_func == nullptr) { + GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); + if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc( op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { + GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); + return SUCCESS; + } + } + + GELOGD("Post process for subgraph %s node %s type %s", arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), + arg.parent_node->GetType().c_str()); + + // refresh node_name in subgraph + for (const ge::NodePtr &node : sub_graph->GetDirectNode()) { + if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) { + continue; + } + node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName()); + } + + auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); + Status ret = FAILED; + if (post_func != nullptr) { + ret = post_func(arg.graph_name, graph); + } else if (parse_func_v2 != nullptr) { + ret = parse_func_v2(arg.graph_name.c_str(), graph); + } + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", + arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), + arg.parent_node->GetType().c_str()); + return FAILED; + } + return SUCCESS; +} +} + +Status OnnxModelParser::ParseOutput(ge::onnx::GraphProto *onnx_graph) { + if (onnx_graph->output_size() == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E16001"); + GELOGE(FAILED, "Onnx graph has zero output"); + return FAILED; + } + + // get output value info map + int64_t index = 0; + for (int i = 0; i < onnx_graph->output_size(); i++) { + ge::onnx::ValueInfoProto value_info = onnx_graph->output(i); + GELOGI("The index of %d output name : %s.", i, value_info.name().c_str()); + + ge::onnx::TensorProto tensor_tmp; + if (value_info.has_type()) { + const ge::onnx::TypeProto type = value_info.type(); + if (type.has_tensor_type()) { + const ge::onnx::TypeProto_Tensor type_proto_tensor = type.tensor_type(); + int32_t elem_type = type_proto_tensor.elem_type(); + tensor_tmp.set_data_type(elem_type); + if (type_proto_tensor.has_shape()) { + const ge::onnx::TensorShapeProto tensor_shape = type_proto_tensor.shape(); + for (int j = 0; j < tensor_shape.dim_size(); j++) { + const ge::onnx::TensorShapeProto_Dimension dimension = tensor_shape.dim(j); + int64_t dim_value = dimension.dim_value(); + tensor_tmp.add_dims(dim_value); + GELOGI("elem_type: %d, dim_value: %ld", elem_type, dim_value); + } + } + } + } + // Construct node for output + ge::onnx::NodeProto *output_node = onnx_graph->add_node(); + output_node->set_name(value_info.name()); + output_node->set_op_type(ge::kOpTypeOutput); + output_node->add_input(value_info.name()); + // add tensor + ge::onnx::AttributeProto *attribute = output_node->add_attribute(); + attribute->set_name(ge::kAttrNameOutput); + ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); + *attribute_tensor = tensor_tmp; + // add index + ge::onnx::AttributeProto *attribute_index = output_node->add_attribute(); + attribute_index->set_name(ge::kAttrNameIndex); + attribute_index->set_i(index++); + output_node_names_.emplace_back(value_info.name()); + } + return SUCCESS; } -Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, +Status OnnxModelParser::ParseInput(bool isSubgraph, ge::onnx::GraphProto *onnx_graph, std::map &initializer_name_tensor) { - if (onnx_graph.input_size() == 0) { + if (onnx_graph->input_size() == 0) { + if (!isSubgraph) { + ErrorManager::GetInstance().ATCReportErrMessage("E16001"); + GELOGE(FAILED, "Root onnx graph has zero input"); + return FAILED; + } else { + // sometime subgraph does not have input, we using constant as input nodes + for (int i = 0; i < onnx_graph->node_size(); i++) { + ge::onnx::NodeProto *node = onnx_graph->mutable_node(i); + if (node->op_type() == "Constant") { + input_node_names_.emplace_back(node->name()); + } + } + } + return SUCCESS; + } + + if (onnx_graph->input_size() == 0) { ErrorManager::GetInstance().ATCReportErrMessage("E16001"); - GELOGE(FAILED, "Onnx graph has zero input"); + GELOGE(FAILED, "onnx graph has zero input"); return FAILED; } // get input value info map int64_t data_index = 0; - for (int i = 0; i < onnx_graph.input_size(); i++) { - ge::onnx::ValueInfoProto value_info = onnx_graph.input(i); + for (int i = 0; i < onnx_graph->input_size(); i++) { + ge::onnx::ValueInfoProto value_info = onnx_graph->input(i); GELOGI("The index of %d input name : %s.", i, value_info.name().c_str()); /// if the input is initialized by a default value found in ‘initializer’, @@ -194,7 +360,7 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, } } // Construct node for input - ge::onnx::NodeProto *input_node = onnx_graph.add_node(); + ge::onnx::NodeProto *input_node = onnx_graph->add_node(); input_node->set_name(value_info.name()); input_node->set_op_type(ge::kOpTypeInput); input_node->add_output(value_info.name()); @@ -207,17 +373,22 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, ge::onnx::AttributeProto *attribute_index = input_node->add_attribute(); attribute_index->set_name(ge::kAttrNameIndex); attribute_index->set_i(data_index++); + // add subgraph attr + if (isSubgraph) { + attribute = input_node->add_attribute(); + attribute->set_name(ge::kAttrNameIsSubgraphOp); + } input_node_names_.emplace_back(value_info.name()); } return SUCCESS; } -Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, +Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto *onnx_graph, std::map &initializer_name_tensor) { // Construct const node for weight int index = 0; for (auto it : initializer_name_tensor) { - ge::onnx::NodeProto *const_node = onnx_graph.add_node(); + ge::onnx::NodeProto *const_node = onnx_graph->add_node(); std::string output_name = it.first + "_" + to_string(index++); const_node->set_name(output_name); const_node->set_op_type(ge::kOpTypeConstant); @@ -231,10 +402,10 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, return SUCCESS; } -Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) { +Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto *onnx_graph) { int index = 0; - for (int i = 0; i < onnx_graph.node_size(); i++) { - ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); + for (int i = 0; i < onnx_graph->node_size(); i++) { + ge::onnx::NodeProto *node = onnx_graph->mutable_node(i); if (node->name().empty()) { std::string node_name = node->op_type() + "_" + to_string(index++); node->set_name(node_name); @@ -318,9 +489,14 @@ Status OnnxModelParser::TransNodeToOperator(const ge::onnx::NodeProto *node_prot string node_name = node_proto->name(); op = ge::OperatorFactory::CreateOperator(node_name, op_type); if (op.GetName() != node_name) { - ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type}); - GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); - return INTERNAL_ERROR; + GELOGW("IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); + if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) { + op = ge::Operator(node_name, op_type); + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type}); + GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); + return INTERNAL_ERROR; + } } GELOGI("After create operator, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), @@ -381,18 +557,20 @@ Status OnnxModelParser::SetOperatorInputs() { GE_CHECK_NOTNULL(dst_op_desc); auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); GE_CHECK_NOTNULL(src_op_desc); - dst_op.SetInput(dst_op_desc->GetInputNameByIndex(dst_index), src_op, - src_op_desc->GetOutputNameByIndex(src_index)); + std::string dst_name = dst_op_desc->GetInputNameByIndex(dst_index); + std::string src_name = src_op_desc->GetOutputNameByIndex(src_index); + GELOGI("dst_name:%s, src_name:%s.", dst_name.c_str(), src_name.c_str()); + dst_op.SetInput(dst_name, src_op, src_name); } } } return SUCCESS; } -Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { +Status OnnxModelParser::Prechecker(ge::onnx::GraphProto *onnx_graph) { ge::PreChecker::Instance().Clear(); - for (int i = 0; i < onnx_graph.node_size(); i++) { - ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); + for (int i = 0; i < onnx_graph->node_size(); i++) { + ge::onnx::NodeProto *node = onnx_graph->mutable_node(i); std::string ori_type; Status ret = ConstructOriType(node, ori_type); if (ret != SUCCESS) { @@ -418,9 +596,9 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { return SUCCESS; } -Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { - for (int i = 0; i < onnx_graph.node_size(); i++) { - ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); +Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto *onnx_graph, ge::Graph &graph) { + for (int i = 0; i < onnx_graph->node_size(); i++) { + ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); std::string node_name = node_proto->name(); std::string ori_type = node_proto->op_type(); GELOGI("Start parse node which name is %s, type is %s", node_name.c_str(), ori_type.c_str()); @@ -455,6 +633,10 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: return status; } + GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), + op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); + + ge::graphStatus graph_status = graph.AddOp(op); if (graph_status != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "Add op:%s to graph failed.", op.GetName().c_str()); @@ -488,6 +670,21 @@ Status OnnxModelParser::GetGraphInputs(std::vector &input_ops) { return SUCCESS; } +Status OnnxModelParser::GetGraphOutputs(std::vector &output_ops) { + for (auto out_name : output_node_names_) { + auto out_op = name_operator_.find(out_name); + if (out_op == name_operator_.end()) { + GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.", + out_name.c_str()); + return PARAM_INVALID; + } + output_ops.emplace_back(out_op->second); + GELOGI("Model assigned output node name: %s", out_op->second.GetName().c_str()); + } + + return SUCCESS; +} + Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) { GE_CHECK_NOTNULL(file); GELOGI("File path is %s.", file); @@ -515,13 +712,56 @@ Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge:: return SUCCESS; } -Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { +void OnnxModelParser::ClearMembers() { + name_operator_.clear(); + input_node_names_.clear(); + output_node_names_.clear(); + inputs_map_.clear(); + outputs_map_.clear(); +} + +Status OnnxModelParser::AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, + std::map &name_to_onnx_graph) { + std::deque onnx_graph_tasks; + int index = 0; + onnx_graph_tasks.push_back(&root_onnx_graph); + + while (!onnx_graph_tasks.empty()) { + ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); + onnx_graph_tasks.pop_front(); + for (int i = 0; i < onnx_graph->node_size(); i++) { + ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); + if (node_proto->name().empty()) { + std::string node_name = node_proto->op_type() + "_" + to_string(index++); + node_proto->set_name(node_name); + } + if (node_proto->op_type() == "If") { + OnnxIfSubgraphAdapter adapter; + if (adapter.AdaptAndFindAllOnnxSubgraphs(node_proto, onnx_graph_tasks, name_to_onnx_graph) != SUCCESS) { + GELOGE(FAILED, "adapt subgraph failed."); + return FAILED; + } + } + } + } + return SUCCESS; +} + +Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph) { if (!onnx_model.has_graph()) { ErrorManager::GetInstance().ATCReportErrMessage("E16004"); GELOGE(PARAM_INVALID, "Onnx model do not has graph."); return FAILED; } - ge::onnx::GraphProto onnx_graph = onnx_model.graph(); + std::map name_to_onnx_graph; + std::deque tasks; + ge::onnx::GraphProto root_onnx_graph = onnx_model.graph(); + + auto ret = AdaptAndFindAllOnnxGraph(root_onnx_graph, name_to_onnx_graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "adapt subgraph failed."); + return FAILED; + } auto opset_import = onnx_model.opset_import(); for (auto it : opset_import) { @@ -529,10 +769,72 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); } + tasks.push_back({&root_onnx_graph, nullptr, "", 0}); + + while (!tasks.empty()) { + ParseArg arg = tasks.front(); + tasks.pop_front(); + bool isSubgraph = (arg.parent_node != nullptr) ? true : false; + + if (arg.onnx_graph == nullptr) { + auto itr = name_to_onnx_graph.find(arg.graph_name); + if (itr == name_to_onnx_graph.end()) { + GELOGE(FAILED, "can not find onnx graph, graph:%s.", arg.graph_name.c_str()); + return FAILED; + } + arg.onnx_graph = itr->second; + } + + ge::onnx::GraphProto *onnx_graph = arg.onnx_graph; + ge::Graph tmp_graph(arg.graph_name); + if (isSubgraph) { + ret = ModelParseToGraphImpl(isSubgraph, onnx_graph, tmp_graph); + } else { + ret = ModelParseToGraphImpl(isSubgraph, onnx_graph, root_graph); + } + if (ret != SUCCESS) { + GELOGE(ret, "model parse to graph impl failed."); + return ret; + } + + ge::ComputeGraphPtr cur_compute_graph; + if (isSubgraph) { + cur_compute_graph = ge::GraphUtils::GetComputeGraph(tmp_graph); + } else { + cur_compute_graph = ge::GraphUtils::GetComputeGraph(root_graph); + } + + ret = PostOpProcessForSubgraph(arg, cur_compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "PostOpProcessForSubgraph failed."); + return ret; + } + + ret = BuildLinkForSonAndParentGraph(cur_compute_graph, arg); + if (ret != SUCCESS) { + GELOGE(ret, "BuildLinkForSonAndParentGraph failed."); + return ret; + } + + ret = GenSubgraphParseTasks(cur_compute_graph, tasks); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to gen tasks on graph %s for next iteration", cur_compute_graph->GetName().c_str()); + return ret; + } + + } + UpdateDataFormat(root_graph); + return SUCCESS; +} + +Status OnnxModelParser::ModelParseToGraphImpl(bool isSubgraph, ge::onnx::GraphProto *onnx_graph, ge::Graph &graph) { + + ClearMembers(); + // 2. Get all inializer. std::map initializer_name_tensor; - for (int i = 0; i < onnx_graph.initializer_size(); i++) { - ge::onnx::TensorProto initializer_tensor = onnx_graph.initializer(i); + for (int i = 0; i < onnx_graph->initializer_size(); i++) { + ge::onnx::TensorProto initializer_tensor = onnx_graph->initializer(i); if (!initializer_tensor.name().empty()) { initializer_name_tensor[initializer_tensor.name()] = initializer_tensor; GELOGI("Initializer name: %s .", initializer_tensor.name().c_str()); @@ -541,7 +843,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model // 3. Parse Input from graph. GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); - Status ret = ParseInput(onnx_graph, initializer_name_tensor); + + Status ret = ParseInput(isSubgraph, onnx_graph, initializer_name_tensor); if (ret != SUCCESS) { GELOGE(ret, "Parse input for onnx failed."); return ret; @@ -555,6 +858,12 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model return ret; } + ret = ParseOutput(onnx_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Parse output for onnx failed."); + return ret; + } + // 5. Update node name for node do not has name. ret = UpdateAllNodeName(onnx_graph); if (ret != SUCCESS) { @@ -582,6 +891,10 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model return ret; } + std::vector op_names; + graph.GetAllOpName(op_names); + GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); + // 8. Set all operator input. ret = SetOperatorInputs(); if (ret != SUCCESS) { @@ -589,13 +902,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model return ret; } - std::vector op_names; - graph.GetAllOpName(op_names); - GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); - // 9. Construct graph. std::vector input_ops; - ret = GetGraphInputs(input_ops); if (ret != SUCCESS) { GELOGE(ret, "Get graph inputs failed."); @@ -604,7 +912,6 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model graph.SetInputs(input_ops); GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); - UpdateDataFormat(graph); GELOGI("Onnx model parser success."); return SUCCESS; diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index 45adf7c..759d016 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -34,6 +34,7 @@ #include #include #include +#include #include "external/register/register_error_codes.h" #include "omg/parser/model_parser.h" #include "omg/parser/op_parser.h" @@ -70,15 +71,17 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { } private: - Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); + Status ParseAllNodeProto(ge::onnx::GraphProto *onnx_graph, ge::Graph &graph); - Status ParseInput(ge::onnx::GraphProto &onnx_graph, + Status ParseInput(bool isSubgraph, ge::onnx::GraphProto *onnx_graph, std::map &initializer_name_tensor); - Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, + Status ParseOutput(ge::onnx::GraphProto *onnx_graph); + + Status ParseInitializer(ge::onnx::GraphProto *onnx_graph, std::map &initializer_name_tensor); - Status UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph); + Status UpdateAllNodeName(ge::onnx::GraphProto *onnx_graph); Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); @@ -92,7 +95,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status GetGraphInputs(std::vector &input_ops); - Status Prechecker(ge::onnx::GraphProto &onnx_graph); + Status GetGraphOutputs(std::vector &output_ops); + + Status Prechecker(ge::onnx::GraphProto *onnx_graph); Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model); @@ -100,8 +105,15 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); + Status ModelParseToGraphImpl(bool isSubgraph, ge::onnx::GraphProto *onnx_graph, ge::Graph &graph); + void UpdateDataFormat(ge::Graph &graph); + void ClearMembers(); + + Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, + std::map &name_to_onnx_graph); + std::map ori_to_om_type_; std::map domain_verseion_; @@ -110,6 +122,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { std::vector input_node_names_; + std::vector output_node_names_; + std::map>> inputs_map_; std::map>> outputs_map_; diff --git a/parser/onnx/onnx_retval_parser.cc b/parser/onnx/onnx_retval_parser.cc new file mode 100644 index 0000000..e27ff0d --- /dev/null +++ b/parser/onnx/onnx_retval_parser.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/util.h" +#include "graph/debug/ge_attr_define.h" +#include "onnx_retval_parser.h" +#include "parser/common/op_parser_factory.h" +#include "parser/onnx/onnx_util.h" + +using domi::ONNX; +using namespace ge::parser; + +namespace ge { +Status OnnxRetvalParser::ParseParams(const Message *op_src, ge::Operator &op_def) { + GE_CHECK_NOTNULL(op_src); + const ge::onnx::NodeProto *node_src = reinterpret_cast(op_src); + GE_CHECK_NOTNULL(node_src); + GELOGD("Onnx op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op_type().c_str()); + + int64_t index = 0; + for (auto it : node_src->attribute()) { + if (it.name() == ge::kAttrNameIndex) { + index = it.i(); + GELOGI("The node has attribute with index: %ld", index); + } + } + op_def.SetAttr(ge::RETVAL_ATTR_NAME_INDEX, index); + ge::GeTensorDesc tensor_desc; + auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_def); + op_desc->AddInputDesc(tensor_desc); + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(ONNX, _RETVAL, OnnxRetvalParser); +} // namespace ge diff --git a/parser/onnx/onnx_retval_parser.h b/parser/onnx/onnx_retval_parser.h new file mode 100644 index 0000000..3022d77 --- /dev/null +++ b/parser/onnx/onnx_retval_parser.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_ONNX_ONNX_RETVAL_PARSER_H_ +#define GE_PARSER_ONNX_ONNX_RETVAL_PARSER_H_ + +#include "parser/onnx/onnx_op_parser.h" + +namespace ge { +class PARSER_FUNC_VISIBILITY OnnxRetvalParser : public OnnxOpParser { + public: + Status ParseParams(const Message *op_src, ge::Operator &op_def) override; +}; +} // namespace ge + +#endif // GE_PARSER_ONNX_ONNX_RETVAL_PARSER_H_ diff --git a/parser/onnx/onnx_subgraph_adapter.h b/parser/onnx/onnx_subgraph_adapter.h new file mode 100644 index 0000000..02acc65 --- /dev/null +++ b/parser/onnx/onnx_subgraph_adapter.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_ONNX_ONNX_SUBGRAPH_ADAPTER_H_ +#define GE_PARSER_ONNX_ONNX_SUBGRAPH_ADAPTER_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY _declspec(dllexport) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include "proto/onnx/ge_onnx.pb.h" +#include "external/register/register_error_codes.h" + +using Status = domi::Status; + +namespace ge { +class PARSER_FUNC_VISIBILITY OnnxSubgraphAdapter { + public: + /// @brief parse params + /// @param [in] parent_op parent op + /// @return SUCCESS parse success + /// @return FAILED Parse failed + virtual Status AdaptAndFindAllOnnxSubgraphs(ge::onnx::NodeProto *parent_op, + std::deque &onnx_graph_tasks, + std::map &name_to_onnx_graph) { + return domi::SUCCESS; + } +}; +} // namespace ge + +#endif // GE_PARSER_ONNX_ONNX_SUBGRAPH_ADAPTER_H_ diff --git a/parser/onnx/onnx_util.h b/parser/onnx/onnx_util.h index 259ed42..7ec9552 100644 --- a/parser/onnx/onnx_util.h +++ b/parser/onnx/onnx_util.h @@ -44,9 +44,12 @@ enum OnnxDataType { namespace ge { const char *const kAttrNameValue = "value"; const char *const kAttrNameInput = "input_tensor"; +const char *const kAttrNameOutput = "output_tensor"; const char *const kAttrNameIndex = "index"; +const char *const kAttrNameIsSubgraphOp = "isSubgraphOp"; const char *const kOpTypeConstant = "Constant"; const char *const kOpTypeInput = "Input"; +const char *const kOpTypeOutput = "Output"; class OnnxUtil { public: