Browse Source

add onnx model parse api

pull/134/head
baker 5 years ago
parent
commit
50dcc53bab
9 changed files with 275 additions and 62 deletions
  1. +33
    -0
      inc/external/parser/onnx_parser.h
  2. +1
    -1
      metadef
  3. +2
    -1
      parser/caffe/caffe_parser.h
  4. +11
    -0
      parser/common/acl_graph_parser_util.cc
  5. +66
    -1
      parser/onnx/CMakeLists.txt
  6. +149
    -48
      parser/onnx/onnx_parser.cc
  7. +9
    -9
      parser/onnx/onnx_parser.h
  8. +2
    -1
      parser/stub/gen_stubapi.py
  9. +2
    -1
      parser/tensorflow/tensorflow_parser.h

+ 33
- 0
inc/external/parser/onnx_parser.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_ACL_GRAPH_ONNX_H_
#define INC_EXTERNAL_ACL_GRAPH_ONNX_H_

#include "graph/ascend_string.h"
#include "graph/ge_error_codes.h"
#include "graph/graph.h"
#include "graph/types.h"

namespace ge {
graphStatus aclgrphParseONNX(const char *model_file,
const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph);

graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph);
} // namespace ge

#endif // INC_EXTERNAL_ACL_GRAPH_ONNX_H_

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit d19c9c5c92f21a0335c18681dcceed44f3a54ddc
Subproject commit dafb8059cad6ea43f1b3d6bfd0a6c80eb6dc8bbe

+ 2
- 1
parser/caffe/caffe_parser.h View File

@@ -54,7 +54,8 @@ class CaffeModelParser : public domi::ModelParser {
*/
Status Parse(const char *file, ge::Graph &graph) override;
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {

Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override {
return domi::SUCCESS;
}



+ 11
- 0
parser/common/acl_graph_parser_util.cc View File

@@ -67,8 +67,10 @@ const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""
const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes.";
static std::set<std::string> kCaffeSupportInputFormatSet = {"NCHW", "ND"};
static std::set<std::string> kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"};
static std::set<std::string> kONNXSupportInputFormatSet = {"NCHW", "ND"};
const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model";
const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model";
const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model";
/// The maximum length of the file.
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
const int kMaxFileSizeLimit = INT_MAX;
@@ -394,6 +396,15 @@ bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) {
{"input_format", input_format, kTFFormatSupport});
GELOGE(ge::FAILED, "Invalid value for input_format[%s], %s.", input_format.c_str(), kTFFormatSupport);
return false;
} else if (ge::GetParserContext().type == domi::ONNX) { // onnx
if (kONNXSupportInputFormatSet.find(input_format) != kONNXSupportInputFormatSet.end()) {
return true;
}
// only support NCHW ND
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"input_format", input_format, kONNXFormatSupport});
GELOGE(ge::FAILED, "Invalid value for input_format[%s], %s.", input_format.c_str(), kONNXFormatSupport);
return false;
}
return true;
}


+ 66
- 1
parser/onnx/CMakeLists.txt View File

@@ -31,6 +31,7 @@ target_include_directories(fmk_onnx_parser PRIVATE
${PARSER_DIR}
${PARSER_DIR}/inc
${PARSER_DIR}/parser
${PARSER_DIR}/parser/inc
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register
@@ -67,7 +68,7 @@ target_link_libraries(fmk_onnx_parser PRIVATE
ascend_protobuf
register
c_sec
parser_common
parser_common
graph
slog
-Wl,--as-needed
@@ -77,6 +78,66 @@ target_link_libraries(fmk_onnx_parser PRIVATE
error_manager
)

##################################################################
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_onnx_parser.cc
COMMAND echo "Generating stub files."
&& ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/../stub/gen_stubapi.py ${PARSER_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR}
&& mv onnx_parser.cc stub_onnx_parser.cc
&& echo "Generating stub files end."
#WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
#DEPENDS ../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR}
)
##################################################################

############ stub/libfmk_onnx_parser.so ############
add_library(fmk_onnx_parser_stub SHARED
${CMAKE_CURRENT_BINARY_DIR}/stub_onnx_parser.cc
)

target_compile_options(fmk_onnx_parser_stub PRIVATE
-O2
)

target_compile_definitions(fmk_onnx_parser_stub PRIVATE
$<$<OR:$<STREQUAL:${PRODUCT_SIDE},host>,$<STREQUAL:${ENABLE_OPEN_SRC},True>>:FMK_SUPPORT_DUMP>
PROTOBUF_INLINE_NOT_IN_HEADERS=0
REUSE_MEMORY=1
FMK_HOST_INFER
$<$<STREQUAL:${ENABLE_OPEN_SRC},True>:ONLY_COMPILE_OPEN_SRC>
)

target_include_directories(fmk_onnx_parser_stub PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PARSER_DIR}
${PARSER_DIR}/inc
${PARSER_DIR}/inc/external
${PARSER_DIR}/parser
${PARSER_DIR}/../inc
${PARSER_DIR}/../inc/common/util
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external/register
#### temp ####
${PARSER_DIR}/../graphengine/inc/common/util
${PARSER_DIR}/../graphengine/inc/external
${PARSER_DIR}/../graphengine/inc/framework
${PARSER_DIR}/../graphengine/inc
${PARSER_DIR}/../graphengine/ge
)

target_link_libraries(fmk_onnx_parser_stub PRIVATE
$<BUILD_INTERFACE:intf_pub>
)

set_target_properties(fmk_onnx_parser_stub PROPERTIES
OUTPUT_NAME fmk_onnx_parser
LIBRARY_OUTPUT_DIRECTORY stub
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)
@@ -84,3 +145,7 @@ set(INSTALL_LIBRARY_DIR lib)
install(TARGETS fmk_onnx_parser OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

install(TARGETS fmk_onnx_parser_stub OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub
)

+ 149
- 48
parser/onnx/onnx_parser.cc View File

@@ -19,9 +19,12 @@
#include <iostream>
#include "common/convert/pb2json.h"
#include "common/util.h"
#include "common/ge_types.h"
#include "common/util/error_manager/error_manager.h"
#include "external/graph/operator_factory.h"
#include "external/register/register_error_codes.h"
#include "external/parser/onnx_parser.h"
#include "external/ge/ge_api_types.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "framework/omg/parser/parser_types.h"
#include "omg/parser/parser_factory.h"
@@ -35,6 +38,100 @@
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"

namespace ge {
graphStatus aclgrphParseONNX(const char *model_file,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::ONNX;
std::map<string, string> options;
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params before graph failed.");
return ge::FAILED;
}
// Create an empty computegraph
string graph_name = output_name.empty() ? "tmpGraph" : output_name;
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name);
GE_CHECK_NOTNULL(compute_graph);

graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX);
GE_CHECK_NOTNULL(model_parser);

// parse caffe model_file to GE graph
ge::graphStatus ret = model_parser->Parse(model_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("Parser graph %s success.", graph.GetName().c_str());

if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params after graph failed.");
return ge::FAILED;
}

if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
return ge::SUCCESS;
}

graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(buffer);
GetParserContext().type = domi::ONNX;
std::map<string, string> options;
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params before graph failed.");
return ge::FAILED;
}
// Create an empty computegraph
string graph_name = output_name.empty() ? "tmpGraph" : output_name;
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name);
GE_CHECK_NOTNULL(compute_graph);

graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX);
GE_CHECK_NOTNULL(model_parser);

// parse caffe model_file to GE graph
ge::graphStatus ret = model_parser->Parse(buffer, (uint32_t)size, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("Parser graph %s success.", graph.GetName().c_str());

if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params after graph failed.");
return ge::FAILED;
}

if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
return ge::SUCCESS;
}
} // namespace ge

namespace ge {
namespace {
std::map<std::string, std::string> kOnnxOpMap = {
@@ -107,27 +204,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph,
return SUCCESS;
}

Status OnnxModelParser::ParseOutput(const ge::onnx::GraphProto &onnx_graph) {
if (onnx_graph.output_size() == 0) {
GELOGE(FAILED, "Onnx graph has zero output");
return FAILED;
}

for (int i = 0; i < onnx_graph.output_size(); i++) {
ge::onnx::ValueInfoProto value_info = onnx_graph.output(i);
GELOGI("The index of %d output name : %s.", i, value_info.name().c_str());

auto it = outputs_map_.find(value_info.name());
if (it != outputs_map_.end()) {
std::string node_name = it->second[0].first;
output_node_names_.emplace_back(node_name);
GELOGI("Output node name: %s", node_name.c_str());
}
}

return SUCCESS;
}

Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) {
// Construct const node for weight
@@ -411,8 +487,7 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
return SUCCESS;
}

Status OnnxModelParser::GetGraphInputsOutputs(std::vector<ge::Operator> &input_ops,
std::vector<std::pair<ge::Operator, std::vector<size_t>>> output_indexs) {
Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) {
for (auto in_name : input_node_names_) {
auto in_op = name_operator_.find(in_name);
if (in_op == name_operator_.end()) {
@@ -424,31 +499,38 @@ Status OnnxModelParser::GetGraphInputsOutputs(std::vector<ge::Operator> &input_o
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str());
}

for (auto it : output_node_names_) {
auto out_op = name_operator_.find(it);
if (out_op == name_operator_.end()) {
GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.",
it.c_str());
return PARAM_INVALID;
}
output_indexs.emplace_back(out_op->second, std::vector<size_t>{});
GELOGI("Model assigned output node name: %s", out_op->second.GetName().c_str());
}
return SUCCESS;
}

Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) {
GE_CHECK_NOTNULL(file);
GELOGI("File path is %s.", file);

// 1. Get graph from onnx model file.
ge::onnx::ModelProto onnx_model;
if (!ge::parser::ReadProtoFromBinaryFile(file, &onnx_model)) {
if (!ge::parser::ReadProtoFromBinaryFile(file, onnx_model)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19021", {"reason"}, {"Read onnx model file failed."});
GELOGE(PARAM_INVALID, "Read onnx model file failed.");
return FAILED;
}
return SUCCESS;
}

Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) {
GE_CHECK_NOTNULL(data);

// 1. Get graph from onnx model file.
ge::onnx::ModelProto onnx_model;
if (!ge::parser::ReadProtoFromArray(file, onnx_model)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19021", {"reason"}, {"Read onnx model from memory failed."});
GELOGE(PARAM_INVALID, "Read onnx model file failed.");
return FAILED;
}
return SUCCESS;
}

Status OnnxModelParser::RealParse(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) {
if (!onnx_model.has_graph()) {
ErrorManager::GetInstance().ATCReportErrMessage("E16004");
GELOGE(PARAM_INVALID, "Onnx model do not has graph.");
@@ -515,14 +597,7 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
return ret;
}

// 8. Parse output from graph.
ret = ParseOutput(onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "Parse output failed.");
return ret;
}

// 9. Set all operator input.
// 8. Set all operator input.
ret = SetOperatorInputs();
if (ret != SUCCESS) {
GELOGE(ret, "Set operator input failed.");
@@ -533,15 +608,14 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
graph.GetAllOpName(op_names);
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size());

// 10. Construct graph.
// 9. Construct graph.
std::vector<ge::Operator> input_ops;
std::vector<std::pair<ge::Operator, std::vector<size_t>>> output_indexs;
ret = GetGraphInputsOutputs(input_ops, output_indexs);
ret = GetGraphInputs(input_ops);
if (ret != SUCCESS) {
GELOGE(ret, "Get graph inputs and outputs failed.");
return ret;
}
graph.SetInputs(input_ops).SetOutputs(output_indexs);
graph.SetInputs(input_ops);

GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph));

@@ -551,6 +625,33 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
return SUCCESS;
}

Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
ge::onnx::ModelProto onnx_model;
Status ret = GetModelFromFile(file, onnx_model);
if (ret != SUCCESS) {
GELOGE(FAILED, "get model from file failed.");
return FAILED;
}
ret = RealParse(onnx_model, graph)
if (ret != SUCCESS) {
GELOGE(FAILED, "parse model failed.");
return FAILED;
}
}

Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {
ge::onnx::ModelProto onnx_model;
Status ret = GetModelFromMemory(data, size, onnx_model);
if (ret != SUCCESS) {
GELOGE(FAILED, "get model from file failed.");
return FAILED;
}
ret = RealParse(onnx_model, graph)
if (ret != SUCCESS) {
GELOGE(FAILED, "parse model failed.");
return FAILED;
}
}
Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) {
if (model_file == nullptr) {
GELOGE(FAILED, "Model file is nullptr.");


+ 9
- 9
parser/onnx/onnx_parser.h View File

@@ -39,9 +39,8 @@ class OnnxModelParser : public domi::ModelParser {
ge::DataType ConvertToGeDataType(const uint32_t type) override;

Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { return domi::SUCCESS; }
virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {
return domi::SUCCESS;
}

Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override;

Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override {
return domi::SUCCESS;
@@ -62,8 +61,6 @@ class OnnxModelParser : public domi::ModelParser {
Status ParseInput(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);

Status ParseOutput(const ge::onnx::GraphProto &onnx_graph);

Status ParseInitializer(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);

@@ -79,11 +76,16 @@ class OnnxModelParser : public domi::ModelParser {

Status SetOperatorInputs();

Status GetGraphInputsOutputs(std::vector<ge::Operator> &input_ops,
std::vector<std::pair<ge::Operator, std::vector<size_t>>> output_indexs);
Status GetGraphInputs(std::vector<ge::Operator> &input_ops);

Status Prechecker(ge::onnx::GraphProto &onnx_graph);

Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model);

Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model);

Status RealParse(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph);

void UpdateFormat(ge::Graph &graph);

std::map<std::string, std::string> ori_to_om_type_;
@@ -94,8 +96,6 @@ class OnnxModelParser : public domi::ModelParser {

std::vector<std::string> input_node_names_;

std::vector<std::string> output_node_names_;

std::unordered_map<std::string, std::vector<std::pair<std::string, int>>> inputs_map_;

std::unordered_map<std::string, std::vector<std::pair<std::string, int>>> outputs_map_;


+ 2
- 1
parser/stub/gen_stubapi.py View File

@@ -71,7 +71,8 @@ max_code_len_per_line = 100
when DEBUG on
"""
white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h",
"ge_ir_build.h", "ge_api.h", "ge_prof.h", "tensorflow_parser.h", "caffe_parser.h"]
"ge_ir_build.h", "ge_api.h", "ge_prof.h", "tensorflow_parser.h",
"caffe_parser.h", "onnx_parser.h"]
include_dir_key_words = ["ge", "graph", "parser"]
DEBUG = True



+ 2
- 1
parser/tensorflow/tensorflow_parser.h View File

@@ -91,7 +91,8 @@ class TensorFlowModelParser : public domi::ModelParser {
Status Parse(const char *file, ge::Graph &graph) override;

Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {

Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override {
return domi::SUCCESS;
}



Loading…
Cancel
Save