Browse Source

add onnx model parse api

pull/134/head
baker 5 years ago
parent
commit
7f5a9d3508
1 changed files with 43 additions and 38 deletions
  1. +43
    -38
      parser/onnx/onnx_parser.cc

+ 43
- 38
parser/onnx/onnx_parser.cc View File

@@ -39,17 +39,14 @@
#include "register/op_registry.h"

namespace ge {
graphStatus aclgrphParseONNX(const char *model_file,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util,
const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph, domi::ModelParser *model_parser) {
GetParserContext().type = domi::ONNX;
std::map<string, string> options;
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params before graph failed.");
@@ -61,9 +58,39 @@ graphStatus aclgrphParseONNX(const char *model_file,
GE_CHECK_NOTNULL(compute_graph);

graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX);
model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX);
GE_CHECK_NOTNULL(model_parser);
return ge::SUCCESS;
}

graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util,
const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph) {
if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params after graph failed.");
return ge::FAILED;
}

if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
return ge::SUCCESS;
}

graphStatus aclgrphParseONNX(const char *model_file,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
domi::ModelParser *model_parser = nullptr;

if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) {
GELOGE(ge::FAILED, "prepare before parse failed.");
return ge::FAILED;
}

GE_CHECK_NOTNULL(model_parser);
// parse caffe model_file to GE graph
ge::graphStatus ret = model_parser->Parse(model_file, graph);
if (ret != ge::SUCCESS) {
@@ -72,15 +99,11 @@ graphStatus aclgrphParseONNX(const char *model_file,
}
GELOGI("Parser graph %s success.", graph.GetName().c_str());

if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params after graph failed.");
if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) {
GELOGE(ge::FAILED, "handle after parse failed.");
return ge::FAILED;
}

if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
return ge::SUCCESS;
}
@@ -88,27 +111,14 @@ graphStatus aclgrphParseONNX(const char *model_file,
graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) {
GE_CHECK_NOTNULL(buffer);
GetParserContext().type = domi::ONNX;
std::map<string, string> options;
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX)));

// load custom plugin so and proto
AclGrphParseUtil acl_graph_parse_util;
(void)acl_graph_parse_util.AclParserInitialize(options);
domi::ModelParser *model_parser = nullptr;

string output_name;
if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params before graph failed.");
if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) {
GELOGE(ge::FAILED, "prepare before parse failed.");
return ge::FAILED;
}
// Create an empty computegraph
string graph_name = output_name.empty() ? "tmpGraph" : output_name;
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name);
GE_CHECK_NOTNULL(compute_graph);

graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX);
GE_CHECK_NOTNULL(model_parser);

// parse caffe model_file to GE graph
ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)size, graph);
@@ -118,17 +128,12 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
}
GELOGI("Parser graph %s success.", graph.GetName().c_str());

if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Parser params after graph failed.");
if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) {
GELOGE(ge::FAILED, "handle after parse failed.");
return ge::FAILED;
}

if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
return ge::SUCCESS;
GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str());
return ge::SUCCESS;
}
} // namespace ge



Loading…
Cancel
Save