From 50dcc53bab7bd3064618e30e29acbd8535c7b7bc Mon Sep 17 00:00:00 2001 From: baker Date: Thu, 10 Dec 2020 21:47:25 +0800 Subject: [PATCH] add onnx model parse api --- inc/external/parser/onnx_parser.h | 33 +++++ metadef | 2 +- parser/caffe/caffe_parser.h | 3 +- parser/common/acl_graph_parser_util.cc | 11 ++ parser/onnx/CMakeLists.txt | 67 ++++++++- parser/onnx/onnx_parser.cc | 197 +++++++++++++++++++------ parser/onnx/onnx_parser.h | 18 +-- parser/stub/gen_stubapi.py | 3 +- parser/tensorflow/tensorflow_parser.h | 3 +- 9 files changed, 275 insertions(+), 62 deletions(-) create mode 100644 inc/external/parser/onnx_parser.h diff --git a/inc/external/parser/onnx_parser.h b/inc/external/parser/onnx_parser.h new file mode 100644 index 0000000..20b6ebe --- /dev/null +++ b/inc/external/parser/onnx_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_ACL_GRAPH_ONNX_H_ +#define INC_EXTERNAL_ACL_GRAPH_ONNX_H_ + +#include "graph/ascend_string.h" +#include "graph/ge_error_codes.h" +#include "graph/graph.h" +#include "graph/types.h" + +namespace ge { +graphStatus aclgrphParseONNX(const char *model_file, + const std::map &parser_params, ge::Graph &graph); + +graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, + const std::map &parser_params, ge::Graph &graph); +} // namespace ge + +#endif // INC_EXTERNAL_ACL_GRAPH_ONNX_H_ diff --git a/metadef b/metadef index d19c9c5..dafb805 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit d19c9c5c92f21a0335c18681dcceed44f3a54ddc +Subproject commit dafb8059cad6ea43f1b3d6bfd0a6c80eb6dc8bbe diff --git a/parser/caffe/caffe_parser.h b/parser/caffe/caffe_parser.h index 3818067..363635b 100644 --- a/parser/caffe/caffe_parser.h +++ b/parser/caffe/caffe_parser.h @@ -54,7 +54,8 @@ class CaffeModelParser : public domi::ModelParser { */ Status Parse(const char *file, ge::Graph &graph) override; Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; - virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { + + Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override { return domi::SUCCESS; } diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index 0b683bd..a689f8e 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -67,8 +67,10 @@ const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\"" const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; static std::set kCaffeSupportInputFormatSet = {"NCHW", "ND"}; static std::set kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; +static std::set kONNXSupportInputFormatSet = {"NCHW", "ND"}; const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; +const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; /// The maximum length of the file. /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 const int kMaxFileSizeLimit = INT_MAX; @@ -394,6 +396,15 @@ bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) { {"input_format", input_format, kTFFormatSupport}); GELOGE(ge::FAILED, "Invalid value for input_format[%s], %s.", input_format.c_str(), kTFFormatSupport); return false; + } else if (ge::GetParserContext().type == domi::ONNX) { // onnx + if (kONNXSupportInputFormatSet.find(input_format) != kONNXSupportInputFormatSet.end()) { + return true; + } + // only support NCHW ND + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"input_format", input_format, kONNXFormatSupport}); + GELOGE(ge::FAILED, "Invalid value for input_format[%s], %s.", input_format.c_str(), kONNXFormatSupport); + return false; } return true; } diff --git a/parser/onnx/CMakeLists.txt b/parser/onnx/CMakeLists.txt index b365317..d40c06b 100644 --- a/parser/onnx/CMakeLists.txt +++ b/parser/onnx/CMakeLists.txt @@ -31,6 +31,7 @@ target_include_directories(fmk_onnx_parser PRIVATE ${PARSER_DIR} ${PARSER_DIR}/inc ${PARSER_DIR}/parser + ${PARSER_DIR}/parser/inc ${METADEF_DIR}/inc ${METADEF_DIR}/inc/graph ${METADEF_DIR}/inc/register @@ -67,7 +68,7 @@ target_link_libraries(fmk_onnx_parser PRIVATE ascend_protobuf register c_sec - parser_common + parser_common graph slog -Wl,--as-needed @@ -77,6 +78,66 @@ target_link_libraries(fmk_onnx_parser PRIVATE error_manager ) +################################################################## +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_onnx_parser.cc + COMMAND echo "Generating stub files." + && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/../stub/gen_stubapi.py ${PARSER_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} + && mv onnx_parser.cc stub_onnx_parser.cc + && echo "Generating stub files end." + #WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + #DEPENDS ../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} +) +################################################################## + +############ stub/libfmk_onnx_parser.so ############ +add_library(fmk_onnx_parser_stub SHARED + ${CMAKE_CURRENT_BINARY_DIR}/stub_onnx_parser.cc +) + +target_compile_options(fmk_onnx_parser_stub PRIVATE + -O2 +) + +target_compile_definitions(fmk_onnx_parser_stub PRIVATE + $<$,$>:FMK_SUPPORT_DUMP> + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + REUSE_MEMORY=1 + FMK_HOST_INFER + $<$:ONLY_COMPILE_OPEN_SRC> +) + +target_include_directories(fmk_onnx_parser_stub PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PARSER_DIR} + ${PARSER_DIR}/inc + ${PARSER_DIR}/inc/external + ${PARSER_DIR}/parser + ${PARSER_DIR}/../inc + ${PARSER_DIR}/../inc/common/util + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/register + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external/register + #### temp #### + ${PARSER_DIR}/../graphengine/inc/common/util + ${PARSER_DIR}/../graphengine/inc/external + ${PARSER_DIR}/../graphengine/inc/framework + ${PARSER_DIR}/../graphengine/inc + ${PARSER_DIR}/../graphengine/ge +) + +target_link_libraries(fmk_onnx_parser_stub PRIVATE + $ +) + +set_target_properties(fmk_onnx_parser_stub PROPERTIES + OUTPUT_NAME fmk_onnx_parser + LIBRARY_OUTPUT_DIRECTORY stub +) + ############ install ############ set(INSTALL_BASE_DIR "") set(INSTALL_LIBRARY_DIR lib) @@ -84,3 +145,7 @@ set(INSTALL_LIBRARY_DIR lib) install(TARGETS fmk_onnx_parser OPTIONAL LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} ) + +install(TARGETS fmk_onnx_parser_stub OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub +) \ No newline at end of file diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 102d96d..f3f78ee 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -19,9 +19,12 @@ #include #include "common/convert/pb2json.h" #include "common/util.h" +#include "common/ge_types.h" #include "common/util/error_manager/error_manager.h" #include "external/graph/operator_factory.h" #include "external/register/register_error_codes.h" +#include "external/parser/onnx_parser.h" +#include "external/ge/ge_api_types.h" #include "framework/omg/parser/parser_inner_ctx.h" #include "framework/omg/parser/parser_types.h" #include "omg/parser/parser_factory.h" @@ -35,6 +38,100 @@ #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" +namespace ge { +graphStatus aclgrphParseONNX(const char *model_file, + const std::map &parser_params, ge::Graph &graph) { + GE_CHECK_NOTNULL(model_file); + GetParserContext().type = domi::ONNX; + std::map options; + options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); + + // load custom plugin so and proto + AclGrphParseUtil acl_graph_parse_util; + (void)acl_graph_parse_util.AclParserInitialize(options); + + string output_name; + if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params before graph failed."); + return ge::FAILED; + } + // Create an empty computegraph + string graph_name = output_name.empty() ? "tmpGraph" : output_name; + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared(graph_name); + GE_CHECK_NOTNULL(compute_graph); + + graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); + GE_CHECK_NOTNULL(model_parser); + + // parse caffe model_file to GE graph + ge::graphStatus ret = model_parser->Parse(model_file, graph); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("Parser graph %s success.", graph.GetName().c_str()); + + if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params after graph failed."); + return ge::FAILED; + } + + if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); + return ge::SUCCESS; +} + +graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, + const std::map &parser_params, ge::Graph &graph) { + GE_CHECK_NOTNULL(buffer); + GetParserContext().type = domi::ONNX; + std::map options; + options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); + + // load custom plugin so and proto + AclGrphParseUtil acl_graph_parse_util; + (void)acl_graph_parse_util.AclParserInitialize(options); + + string output_name; + if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params before graph failed."); + return ge::FAILED; + } + // Create an empty computegraph + string graph_name = output_name.empty() ? "tmpGraph" : output_name; + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared(graph_name); + GE_CHECK_NOTNULL(compute_graph); + + graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); + GE_CHECK_NOTNULL(model_parser); + + // parse caffe model_file to GE graph + ge::graphStatus ret = model_parser->Parse(buffer, (uint32_t)size, graph); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("Parser graph %s success.", graph.GetName().c_str()); + + if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params after graph failed."); + return ge::FAILED; + } + + if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); + return ge::SUCCESS; +} +} // namespace ge + namespace ge { namespace { std::map kOnnxOpMap = { @@ -107,27 +204,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, return SUCCESS; } -Status OnnxModelParser::ParseOutput(const ge::onnx::GraphProto &onnx_graph) { - if (onnx_graph.output_size() == 0) { - GELOGE(FAILED, "Onnx graph has zero output"); - return FAILED; - } - - for (int i = 0; i < onnx_graph.output_size(); i++) { - ge::onnx::ValueInfoProto value_info = onnx_graph.output(i); - GELOGI("The index of %d output name : %s.", i, value_info.name().c_str()); - - auto it = outputs_map_.find(value_info.name()); - if (it != outputs_map_.end()) { - std::string node_name = it->second[0].first; - output_node_names_.emplace_back(node_name); - GELOGI("Output node name: %s", node_name.c_str()); - } - } - - return SUCCESS; -} - Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, std::map &initializer_name_tensor) { // Construct const node for weight @@ -411,8 +487,7 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: return SUCCESS; } -Status OnnxModelParser::GetGraphInputsOutputs(std::vector &input_ops, - std::vector>> output_indexs) { +Status OnnxModelParser::GetGraphInputs(std::vector &input_ops) { for (auto in_name : input_node_names_) { auto in_op = name_operator_.find(in_name); if (in_op == name_operator_.end()) { @@ -424,31 +499,38 @@ Status OnnxModelParser::GetGraphInputsOutputs(std::vector &input_o GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); } - for (auto it : output_node_names_) { - auto out_op = name_operator_.find(it); - if (out_op == name_operator_.end()) { - GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.", - it.c_str()); - return PARAM_INVALID; - } - output_indexs.emplace_back(out_op->second, std::vector{}); - GELOGI("Model assigned output node name: %s", out_op->second.GetName().c_str()); - } return SUCCESS; } - -Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { +Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) { GE_CHECK_NOTNULL(file); GELOGI("File path is %s.", file); // 1. Get graph from onnx model file. ge::onnx::ModelProto onnx_model; - if (!ge::parser::ReadProtoFromBinaryFile(file, &onnx_model)) { + if (!ge::parser::ReadProtoFromBinaryFile(file, onnx_model)) { ErrorManager::GetInstance().ATCReportErrMessage( "E19021", {"reason"}, {"Read onnx model file failed."}); GELOGE(PARAM_INVALID, "Read onnx model file failed."); return FAILED; } + return SUCCESS; +} + +Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) { + GE_CHECK_NOTNULL(data); + + // 1. Get graph from onnx model file. + ge::onnx::ModelProto onnx_model; + if (!ge::parser::ReadProtoFromArray(file, onnx_model)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19021", {"reason"}, {"Read onnx model from memory failed."}); + GELOGE(PARAM_INVALID, "Read onnx model file failed."); + return FAILED; + } + return SUCCESS; +} + +Status OnnxModelParser::RealParse(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { if (!onnx_model.has_graph()) { ErrorManager::GetInstance().ATCReportErrMessage("E16004"); GELOGE(PARAM_INVALID, "Onnx model do not has graph."); @@ -515,14 +597,7 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { return ret; } - // 8. Parse output from graph. - ret = ParseOutput(onnx_graph); - if (ret != SUCCESS) { - GELOGE(ret, "Parse output failed."); - return ret; - } - - // 9. Set all operator input. + // 8. Set all operator input. ret = SetOperatorInputs(); if (ret != SUCCESS) { GELOGE(ret, "Set operator input failed."); @@ -533,15 +608,14 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { graph.GetAllOpName(op_names); GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); - // 10. Construct graph. + // 9. Construct graph. std::vector input_ops; - std::vector>> output_indexs; - ret = GetGraphInputsOutputs(input_ops, output_indexs); + ret = GetGraphInputs(input_ops); if (ret != SUCCESS) { GELOGE(ret, "Get graph inputs and outputs failed."); return ret; } - graph.SetInputs(input_ops).SetOutputs(output_indexs); + graph.SetInputs(input_ops); GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); @@ -551,6 +625,33 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { return SUCCESS; } +Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { + ge::onnx::ModelProto onnx_model; + Status ret = GetModelFromFile(file, onnx_model); + if (ret != SUCCESS) { + GELOGE(FAILED, "get model from file failed."); + return FAILED; + } + ret = RealParse(onnx_model, graph) + if (ret != SUCCESS) { + GELOGE(FAILED, "parse model failed."); + return FAILED; + } +} + +Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { + ge::onnx::ModelProto onnx_model; + Status ret = GetModelFromMemory(data, size, onnx_model); + if (ret != SUCCESS) { + GELOGE(FAILED, "get model from file failed."); + return FAILED; + } + ret = RealParse(onnx_model, graph) + if (ret != SUCCESS) { + GELOGE(FAILED, "parse model failed."); + return FAILED; + } +} Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { if (model_file == nullptr) { GELOGE(FAILED, "Model file is nullptr."); diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index 5eba094..a81c36e 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -39,9 +39,8 @@ class OnnxModelParser : public domi::ModelParser { ge::DataType ConvertToGeDataType(const uint32_t type) override; Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { return domi::SUCCESS; } - virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { - return domi::SUCCESS; - } + + Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override; Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override { return domi::SUCCESS; @@ -62,8 +61,6 @@ class OnnxModelParser : public domi::ModelParser { Status ParseInput(ge::onnx::GraphProto &onnx_graph, std::map &initializer_name_tensor); - Status ParseOutput(const ge::onnx::GraphProto &onnx_graph); - Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, std::map &initializer_name_tensor); @@ -79,11 +76,16 @@ class OnnxModelParser : public domi::ModelParser { Status SetOperatorInputs(); - Status GetGraphInputsOutputs(std::vector &input_ops, - std::vector>> output_indexs); + Status GetGraphInputs(std::vector &input_ops); Status Prechecker(ge::onnx::GraphProto &onnx_graph); + Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model); + + Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model); + + Status RealParse(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); + void UpdateFormat(ge::Graph &graph); std::map ori_to_om_type_; @@ -94,8 +96,6 @@ class OnnxModelParser : public domi::ModelParser { std::vector input_node_names_; - std::vector output_node_names_; - std::unordered_map>> inputs_map_; std::unordered_map>> outputs_map_; diff --git a/parser/stub/gen_stubapi.py b/parser/stub/gen_stubapi.py index c729e6f..b0252ae 100644 --- a/parser/stub/gen_stubapi.py +++ b/parser/stub/gen_stubapi.py @@ -71,7 +71,8 @@ max_code_len_per_line = 100 when DEBUG on """ white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h", - "ge_ir_build.h", "ge_api.h", "ge_prof.h", "tensorflow_parser.h", "caffe_parser.h"] + "ge_ir_build.h", "ge_api.h", "ge_prof.h", "tensorflow_parser.h", + "caffe_parser.h", "onnx_parser.h"] include_dir_key_words = ["ge", "graph", "parser"] DEBUG = True diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index c2eee31..94d6c3e 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -91,7 +91,8 @@ class TensorFlowModelParser : public domi::ModelParser { Status Parse(const char *file, ge::Graph &graph) override; Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; - virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) { + + Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override { return domi::SUCCESS; }