From 7f5a9d3508830c838be5b9704ccc1d0abbe26160 Mon Sep 17 00:00:00 2001 From: baker Date: Thu, 10 Dec 2020 22:50:38 +0800 Subject: [PATCH] add onnx model parse api --- parser/onnx/onnx_parser.cc | 81 ++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 38 deletions(-) diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 2196b86..fd49ac1 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -39,17 +39,14 @@ #include "register/op_registry.h" namespace ge { -graphStatus aclgrphParseONNX(const char *model_file, - const std::map &parser_params, ge::Graph &graph) { - GE_CHECK_NOTNULL(model_file); +graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, + const std::map &parser_params, + ge::Graph &graph, domi::ModelParser *model_parser) { GetParserContext().type = domi::ONNX; std::map options; options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); - // load custom plugin so and proto - AclGrphParseUtil acl_graph_parse_util; (void)acl_graph_parse_util.AclParserInitialize(options); - string output_name; if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { GELOGE(ge::FAILED, "Parser params before graph failed."); @@ -61,9 +58,39 @@ graphStatus aclgrphParseONNX(const char *model_file, GE_CHECK_NOTNULL(compute_graph); graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); - auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); + model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); GE_CHECK_NOTNULL(model_parser); + return ge::SUCCESS; +} + +graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util, + const std::map &parser_params, + ge::Graph &graph) { + if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params after graph failed."); + return ge::FAILED; + } + + if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); + return ge::FAILED; + } + return ge::SUCCESS; +} +graphStatus aclgrphParseONNX(const char *model_file, + const std::map &parser_params, ge::Graph &graph) { + GE_CHECK_NOTNULL(model_file); + // load custom plugin so and proto + AclGrphParseUtil acl_graph_parse_util; + domi::ModelParser *model_parser = nullptr; + + if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { + GELOGE(ge::FAILED, "prepare before parse failed."); + return ge::FAILED; + } + + GE_CHECK_NOTNULL(model_parser); // parse caffe model_file to GE graph ge::graphStatus ret = model_parser->Parse(model_file, graph); if (ret != ge::SUCCESS) { @@ -72,15 +99,11 @@ graphStatus aclgrphParseONNX(const char *model_file, } GELOGI("Parser graph %s success.", graph.GetName().c_str()); - if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { - GELOGE(ge::FAILED, "Parser params after graph failed."); + if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { + GELOGE(ge::FAILED, "handle after parse failed."); return ge::FAILED; } - if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { - GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); - return ge::FAILED; - } GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); return ge::SUCCESS; } @@ -88,27 +111,14 @@ graphStatus aclgrphParseONNX(const char *model_file, graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, const std::map &parser_params, ge::Graph &graph) { GE_CHECK_NOTNULL(buffer); - GetParserContext().type = domi::ONNX; - std::map options; - options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); - // load custom plugin so and proto AclGrphParseUtil acl_graph_parse_util; - (void)acl_graph_parse_util.AclParserInitialize(options); + domi::ModelParser *model_parser = nullptr; - string output_name; - if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { - GELOGE(ge::FAILED, "Parser params before graph failed."); + if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { + GELOGE(ge::FAILED, "prepare before parse failed."); return ge::FAILED; } - // Create an empty computegraph - string graph_name = output_name.empty() ? "tmpGraph" : output_name; - ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared(graph_name); - GE_CHECK_NOTNULL(compute_graph); - - graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); - auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); - GE_CHECK_NOTNULL(model_parser); // parse caffe model_file to GE graph ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)size, graph); @@ -118,17 +128,12 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, } GELOGI("Parser graph %s success.", graph.GetName().c_str()); - if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { - GELOGE(ge::FAILED, "Parser params after graph failed."); + if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { + GELOGE(ge::FAILED, "handle after parse failed."); return ge::FAILED; } - - if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { - GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); - return ge::FAILED; - } - GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); - return ge::SUCCESS; + GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); + return ge::SUCCESS; } } // namespace ge