| @@ -39,17 +39,14 @@ | |||
| #include "register/op_registry.h" | |||
| namespace ge { | |||
| graphStatus aclgrphParseONNX(const char *model_file, | |||
| const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(model_file); | |||
| graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | |||
| const std::map<AscendString, AscendString> &parser_params, | |||
| ge::Graph &graph, domi::ModelParser *model_parser) { | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<string, string> options; | |||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); | |||
| // load custom plugin so and proto | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| (void)acl_graph_parse_util.AclParserInitialize(options); | |||
| string output_name; | |||
| if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params before graph failed."); | |||
| @@ -61,9 +58,39 @@ graphStatus aclgrphParseONNX(const char *model_file, | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||
| model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||
| GE_CHECK_NOTNULL(model_parser); | |||
| return ge::SUCCESS; | |||
| } | |||
| graphStatus HandleAfterParse(AclGrphParseUtil &acl_graph_parse_util, | |||
| const std::map<AscendString, AscendString> &parser_params, | |||
| ge::Graph &graph) { | |||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||
| return ge::FAILED; | |||
| } | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| return ge::SUCCESS; | |||
| } | |||
| graphStatus aclgrphParseONNX(const char *model_file, | |||
| const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(model_file); | |||
| // load custom plugin so and proto | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| domi::ModelParser *model_parser = nullptr; | |||
| if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "prepare before parse failed."); | |||
| return ge::FAILED; | |||
| } | |||
| GE_CHECK_NOTNULL(model_parser); | |||
| // parse caffe model_file to GE graph | |||
| ge::graphStatus ret = model_parser->Parse(model_file, graph); | |||
| if (ret != ge::SUCCESS) { | |||
| @@ -72,15 +99,11 @@ graphStatus aclgrphParseONNX(const char *model_file, | |||
| } | |||
| GELOGI("Parser graph %s success.", graph.GetName().c_str()); | |||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||
| if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "handle after parse failed."); | |||
| return ge::FAILED; | |||
| } | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||
| return ge::SUCCESS; | |||
| } | |||
| @@ -88,27 +111,14 @@ graphStatus aclgrphParseONNX(const char *model_file, | |||
| graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||
| const std::map<AscendString, AscendString> &parser_params, ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(buffer); | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<string, string> options; | |||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::ONNX))); | |||
| // load custom plugin so and proto | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| (void)acl_graph_parse_util.AclParserInitialize(options); | |||
| domi::ModelParser *model_parser = nullptr; | |||
| string output_name; | |||
| if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params before graph failed."); | |||
| if (PrepareBeforeParse(acl_graph_parse_util, parser_params, graph, model_parser) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "prepare before parse failed."); | |||
| return ge::FAILED; | |||
| } | |||
| // Create an empty computegraph | |||
| string graph_name = output_name.empty() ? "tmpGraph" : output_name; | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>(graph_name); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::ONNX); | |||
| GE_CHECK_NOTNULL(model_parser); | |||
| // parse caffe model_file to GE graph | |||
| ge::graphStatus ret = model_parser->ParseFromMemory(buffer, (uint32_t)size, graph); | |||
| @@ -118,17 +128,12 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||
| } | |||
| GELOGI("Parser graph %s success.", graph.GetName().c_str()); | |||
| if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Parser params after graph failed."); | |||
| if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "handle after parse failed."); | |||
| return ge::FAILED; | |||
| } | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| } | |||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||
| return ge::SUCCESS; | |||
| GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); | |||
| return ge::SUCCESS; | |||
| } | |||
| } // namespace ge | |||