Browse Source

!153 add onnx model parse api

Merge pull request !153 from 陈华/development
pull/153/MERGE
i-robot Gitee 5 years ago
parent
commit
91b40bced8
8 changed files with 333 additions and 62 deletions
  1. +34
    -0
      inc/external/parser/onnx_parser.h
  2. +14
    -1
      parser/caffe/caffe_parser.h
  3. +56
    -0
      parser/onnx/CMakeLists.txt
  4. +33
    -1
      parser/onnx/module.mk
  5. +169
    -48
      parser/onnx/onnx_parser.cc
  6. +13
    -9
      parser/onnx/onnx_parser.h
  7. +1
    -2
      parser/stub/gen_stubapi.py
  8. +13
    -1
      parser/tensorflow/tensorflow_parser.h

+ 34
- 0
inc/external/parser/onnx_parser.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_PARSER_ONNX_PARSER_H_
#define INC_EXTERNAL_PARSER_ONNX_PARSER_H_

#include "graph/ascend_string.h"
#include "graph/ge_error_codes.h"
#include "graph/graph.h"
#include "graph/types.h"

namespace ge {
graphStatus aclgrphParseONNX(const char *model_file,
const std::map<ge::AscendString, ge::AscendString> &parser_params, ge::Graph &graph);

graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
const std::map<ge::AscendString, ge::AscendString> &parser_params,
ge::Graph &graph);
} // namespace ge

#endif // INC_EXTERNAL_PARSER_ONNX_PARSER_H_

+ 14
- 1
parser/caffe/caffe_parser.h View File

@@ -53,10 +53,23 @@ class CaffeModelParser : public domi::ModelParser {
* @return FAILED parse failed
*/
Status Parse(const char *file, ge::Graph &graph) override;

/**
* @ingroup domi_omg
* @brief Parse the relevant data from memory and save it to graph
* @param [in] memory buffer of model file
* @param [in] buffer size
* @param [in|out] graph graph for saving model information
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {

#ifndef ONLY_COMPILE_OPEN_SRC
Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override {
return domi::SUCCESS;
}
#endif

/**
* @ingroup domi_omg


+ 56
- 0
parser/onnx/CMakeLists.txt View File

@@ -31,6 +31,7 @@ target_include_directories(fmk_onnx_parser PRIVATE
${PARSER_DIR}
${PARSER_DIR}/inc
${PARSER_DIR}/parser
${PARSER_DIR}/parser/inc
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/register
@@ -77,6 +78,57 @@ target_link_libraries(fmk_onnx_parser PRIVATE
error_manager
)

##################################################################
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_onnx_parser.cc
COMMAND echo "Generating stub files."
&& ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/../stub/gen_stubapi.py ${PARSER_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR}
&& mv onnx_parser.cc stub_onnx_parser.cc
&& echo "Generating stub files end."
#WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
#DEPENDS ../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR}
)
##################################################################

############ stub/libfmk_onnx_parser.so ############
add_library(fmk_onnx_parser_stub SHARED
${CMAKE_CURRENT_BINARY_DIR}/stub_onnx_parser.cc
)

target_compile_options(fmk_onnx_parser_stub PRIVATE
-O2
)

target_compile_definitions(fmk_onnx_parser_stub PRIVATE
$<$<OR:$<STREQUAL:${PRODUCT_SIDE},host>,$<STREQUAL:${ENABLE_OPEN_SRC},True>>:FMK_SUPPORT_DUMP>
PROTOBUF_INLINE_NOT_IN_HEADERS=0
REUSE_MEMORY=1
FMK_HOST_INFER
$<$<STREQUAL:${ENABLE_OPEN_SRC},True>:ONLY_COMPILE_OPEN_SRC>
)

target_include_directories(fmk_onnx_parser_stub PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PARSER_DIR}
${PARSER_DIR}/inc
${PARSER_DIR}/inc/external
${PARSER_DIR}/parser
${PARSER_DIR}/../inc
${METADEF_DIR}/inc
${METADEF_DIR}/inc/graph
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
)

target_link_libraries(fmk_onnx_parser_stub PRIVATE
$<BUILD_INTERFACE:intf_pub>
)

set_target_properties(fmk_onnx_parser_stub PROPERTIES
OUTPUT_NAME fmk_onnx_parser
LIBRARY_OUTPUT_DIRECTORY stub
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)
@@ -84,3 +136,7 @@ set(INSTALL_LIBRARY_DIR lib)
install(TARGETS fmk_onnx_parser OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

install(TARGETS fmk_onnx_parser_stub OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub
)

+ 33
- 1
parser/onnx/module.mk View File

@@ -1,6 +1,6 @@

LOCAL_PATH := $(call my-dir)
include $(LOCAL_PATH)/../stub/Makefile
include $(CLEAR_VARS)

LOCAL_MODULE := libfmk_onnx_parser
@@ -57,3 +57,35 @@ LOCAL_STATIC_LIBRARIES += libmmpa
LOCAL_LDFLAGS := -lrt -ldl

include $(BUILD_HOST_SHARED_LIBRARY)

#compiler for host parser
include $(CLEAR_VARS)

LOCAL_C_INCLUDES := \
$(TOPDIR)inc \
$(TOPDIR)metadef/inc \
$(TOPDIR)parser/inc \
$(TOPDIR)inc/external \
$(TOPDIR)metadef/inc/external \
$(TOPDIR)parser/inc/external \
$(TOPDIR)metadef/inc/external/graph \
libc_sec/include \

LOCAL_MODULE := stub/libfmk_parser

LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2
LOCAL_CFLAGS += -DFMK_HOST_INFER -DFMK_SUPPORT_DUMP
ifeq ($(DEBUG), 1)
LOCAL_CFLAGS += -g -O0
endif

LOCAL_C_INCLUDES := $(LOCAL_C_INCLUDES)

LOCAL_SRC_FILES := ../../../out/parser/lib64/stub/onnx_parser.cc

LOCAL_SHARED_LIBRARIES :=

LOCAL_LDFLAGS := -lrt -ldl

include $(BUILD_HOST_SHARED_LIBRARY)


+ 169
- 48
parser/onnx/onnx_parser.cc View File

@@ -19,9 +19,12 @@
#include <iostream>
#include "common/convert/pb2json.h"
#include "common/util.h"
#include "common/ge_types.h"
#include "common/util/error_manager/error_manager.h"
#include "external/graph/operator_factory.h"
#include "external/register/register_error_codes.h"
#include "external/parser/onnx_parser.h"
#include "external/ge/ge_api_types.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "framework/omg/parser/parser_types.h"
#include "omg/parser/parser_factory.h"
@@ -35,6 +38,113 @@
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"

namespace ge {
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util,
const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph, std::shared_ptr<domi::ModelParser> &model_parser) {
GetParserContext().type = domi::ONNX;
std::map<string, string> options;
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX)));

if (acl_graph_parse_util.AclParserInitialize(options) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Acl parser initialize failed.");
return ge::FAILED;
}

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params before graph failed.");
return ge::FAILED;
}
// Create an empty computegraph
string graph_name = output_name.empty() ? "tmpGraph" : output_name;
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name);
GE_CHECK_NOTNULL(compute_graph);

graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX);
GE_CHECK_NOTNULL(model_parser);
return ge::SUCCESS;
}

graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util,
const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph) {
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params after graph failed.");
return ge::FAILED;
}

if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
return ge::SUCCESS;
}

graphStatus aclgrphParseONNX(const char *model_file,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
#ifndef ONLY_COMPILE_OPEN_SRC
GE_CHECK_NOTNULL(model_file);
// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
std::shared_ptr<domi::ModelParser> model_parser;

if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Prepare before parse failed.");
return ge::FAILED;
}

GE_CHECK_NOTNULL(model_parser);
// parse caffe model_file to GE graph
ge::graphStatus ret = model_parser->Parse(model_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("Parser graph %s success.", graph.GetName().c_str());

if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Handle after parse failed.");
return ge::FAILED;
}

GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
#endif
return ge::SUCCESS;
}

graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
#ifndef ONLY_COMPILE_OPEN_SRC
GE_CHECK_NOTNULL(buffer);
// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
std::shared_ptr<domi::ModelParser> model_parser;

if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Prepare before parse failed.");
return ge::FAILED;
}

// parse caffe model_file to GE graph
ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)size, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("Parser graph %s success.", graph.GetName().c_str());

if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Handle after parse failed.");
return ge::FAILED;
}
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
#endif
return ge::SUCCESS;
}
} // namespace ge

namespace ge {
namespace {
std::map<std::string, std::string> kOnnxOpMap = {
@@ -107,27 +217,6 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph,
return SUCCESS;
}

Status OnnxModelParser::ParseOutput(const ge::onnx::GraphProto &onnx_graph) {
if (onnx_graph.output_size() == 0) {
GELOGE(FAILED, "Onnx graph has zero output");
return FAILED;
}

for (int i = 0; i < onnx_graph.output_size(); i++) {
ge::onnx::ValueInfoProto value_info = onnx_graph.output(i);
GELOGI("The index of %d output name : %s.", i, value_info.name().c_str());

auto it = outputs_map_.find(value_info.name());
if (it != outputs_map_.end()) {
std::string node_name = it->second[0].first;
output_node_names_.emplace_back(node_name);
GELOGI("Output node name: %s", node_name.c_str());
}
}

return SUCCESS;
}

Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) {
// Construct const node for weight
@@ -411,8 +500,7 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
return SUCCESS;
}

Status OnnxModelParser::GetGraphInputsOutputs(std::vector<ge::Operator> &input_ops,
std::vector<std::pair<ge::Operator, std::vector<size_t>>> output_indexs) {
Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) {
for (auto in_name : input_node_names_) {
auto in_op = name_operator_.find(in_name);
if (in_op == name_operator_.end()) {
@@ -424,31 +512,39 @@ Status OnnxModelParser::GetGraphInputsOutputs(std::vector<ge::Operator> &input_o
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str());
}

for (auto it : output_node_names_) {
auto out_op = name_operator_.find(it);
if (out_op == name_operator_.end()) {
GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.",
it.c_str());
return PARAM_INVALID;
}
output_indexs.emplace_back(out_op->second, std::vector<size_t>{});
GELOGI("Model assigned output node name: %s", out_op->second.GetName().c_str());
}
return SUCCESS;
}

Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
Status OnnxModelParser::GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model) {
GE_CHECK_NOTNULL(file);
GELOGI("File path is %s.", file);

// 1. Get graph from onnx model file.
ge::onnx::ModelProto onnx_model;
if (!ge::parser::ReadProtoFromBinaryFile(file, &onnx_model)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19021", {"reason"}, {"Read onnx model file failed."});
GELOGE(PARAM_INVALID, "Read onnx model file failed.");
return FAILED;
}
return SUCCESS;
}

#ifndef ONLY_COMPILE_OPEN_SRC
Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model) {
GE_CHECK_NOTNULL(data);

// 1. Get graph from onnx model file.
if (!ge::parser::ReadProtoFromArray(data, size, &onnx_model)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19021", {"reason"}, {"Read onnx model from memory failed."});
GELOGE(PARAM_INVALID, "Read onnx model from memory failed.");
return FAILED;
}
return SUCCESS;
}
#endif

Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) {
if (!onnx_model.has_graph()) {
ErrorManager::GetInstance().ATCReportErrMessage("E16004");
GELOGE(PARAM_INVALID, "Onnx model do not has graph.");
@@ -515,14 +611,7 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
return ret;
}

// 8. Parse output from graph.
ret = ParseOutput(onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "Parse output failed.");
return ret;
}

// 9. Set all operator input.
// 8. Set all operator input.
ret = SetOperatorInputs();
if (ret != SUCCESS) {
GELOGE(ret, "Set operator input failed.");
@@ -533,15 +622,15 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
graph.GetAllOpName(op_names);
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size());

// 10. Construct graph.
// 9. Construct graph.
std::vector<ge::Operator> input_ops;
std::vector<std::pair<ge::Operator, std::vector<size_t>>> output_indexs;
ret = GetGraphInputsOutputs(input_ops, output_indexs);
ret = GetGraphInputs(input_ops);
if (ret != SUCCESS) {
GELOGE(ret, "Get graph inputs and outputs failed.");
GELOGE(ret, "Get graph inputs failed.");
return ret;
}
graph.SetInputs(input_ops).SetOutputs(output_indexs);
graph.SetInputs(input_ops);

GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph));

@@ -551,6 +640,38 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
return SUCCESS;
}

Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
ge::onnx::ModelProto onnx_model;
Status ret = GetModelFromFile(file, onnx_model);
if (ret != SUCCESS) {
GELOGE(FAILED, "get model from file failed.");
return FAILED;
}
ret = ModelParseToGraph(onnx_model, graph);
if (ret != SUCCESS) {
GELOGE(FAILED, "parse model failed.");
return FAILED;
}
return SUCCESS;
}

#ifndef ONLY_COMPILE_OPEN_SRC
Status OnnxModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {
ge::onnx::ModelProto onnx_model;
Status ret = GetModelFromMemory(data, size, onnx_model);
if (ret != SUCCESS) {
GELOGE(FAILED, "get model from file failed.");
return FAILED;
}
ret = ModelParseToGraph(onnx_model, graph);
if (ret != SUCCESS) {
GELOGE(FAILED, "parse model failed.");
return FAILED;
}
return SUCCESS;
}
#endif

Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) {
if (model_file == nullptr) {
GELOGE(FAILED, "Model file is nullptr.");


+ 13
- 9
parser/onnx/onnx_parser.h View File

@@ -39,9 +39,10 @@ class OnnxModelParser : public domi::ModelParser {
ge::DataType ConvertToGeDataType(const uint32_t type) override;

Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { return domi::SUCCESS; }
virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {
return domi::SUCCESS;
}

#ifndef ONLY_COMPILE_OPEN_SRC
Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override;
#endif

Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override {
return domi::SUCCESS;
@@ -62,8 +63,6 @@ class OnnxModelParser : public domi::ModelParser {
Status ParseInput(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);

Status ParseOutput(const ge::onnx::GraphProto &onnx_graph);

Status ParseInitializer(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);

@@ -79,10 +78,17 @@ class OnnxModelParser : public domi::ModelParser {

Status SetOperatorInputs();

Status GetGraphInputsOutputs(std::vector<ge::Operator> &input_ops,
std::vector<std::pair<ge::Operator, std::vector<size_t>>> output_indexs);
Status GetGraphInputs(std::vector<ge::Operator> &input_ops);

Status Prechecker(ge::onnx::GraphProto &onnx_graph);
Status GetModelFromFile(const char *file, ge::onnx::ModelProto &onnx_model);

#ifndef ONLY_COMPILE_OPEN_SRC
Status GetModelFromMemory(const char *data, uint32_t size, ge::onnx::ModelProto &onnx_model);
#endif

Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph);

void UpdateFormat(ge::Graph &graph);

@@ -94,8 +100,6 @@ class OnnxModelParser : public domi::ModelParser {

std::vector<std::string> input_node_names_;

std::vector<std::string> output_node_names_;

std::unordered_map<std::string, std::vector<std::pair<std::string, int>>> inputs_map_;

std::unordered_map<std::string, std::vector<std::pair<std::string, int>>> outputs_map_;


+ 1
- 2
parser/stub/gen_stubapi.py View File

@@ -70,8 +70,7 @@ max_code_len_per_line = 100
determines which header files to generate cc files from
when DEBUG on
"""
white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h",
"ge_ir_build.h", "ge_api.h", "tensorflow_parser.h", "caffe_parser.h"]
white_list_for_debug = ["tensorflow_parser.h", "caffe_parser.h", "onnx_parser.h"]
include_dir_key_words = ["ge", "graph", "parser"]
DEBUG = True



+ 13
- 1
parser/tensorflow/tensorflow_parser.h View File

@@ -90,10 +90,22 @@ class TensorFlowModelParser : public domi::ModelParser {
*/
Status Parse(const char *file, ge::Graph &graph) override;

/**
* @ingroup domi_omg
* @brief Parse the relevant data from memory and save it to graph
* @param [in] memory buffer of model file
* @param [in] buffer size
* @param [in|out] graph graph for saving model information
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) {

#ifndef ONLY_COMPILE_OPEN_SRC
Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override {
return domi::SUCCESS;
}
#endif

/**
* @ingroup domi_omg


Loading…
Cancel
Save