From 44762b6ea5c82594efeb042a127d85c39748db28 Mon Sep 17 00:00:00 2001 From: l00444296 Date: Wed, 11 Nov 2020 14:02:41 +0800 Subject: [PATCH] Feature:Support user options of aclgrphParse interface --- inc/external/parser/caffe_parser.h | 3 +- inc/external/parser/tensorflow_parser.h | 3 +- parser/caffe/caffe_parser.cc | 7 +++-- parser/common/acl_graph_parser_util.cc | 37 +++++++++++++++++-------- parser/common/acl_graph_parser_util.h | 11 ++++---- parser/tensorflow/tensorflow_parser.cc | 4 +-- 6 files changed, 42 insertions(+), 23 deletions(-) diff --git a/inc/external/parser/caffe_parser.h b/inc/external/parser/caffe_parser.h index 9749756..d4c069e 100644 --- a/inc/external/parser/caffe_parser.h +++ b/inc/external/parser/caffe_parser.h @@ -24,12 +24,13 @@ #include "graph/ge_error_codes.h" #include "graph/types.h" #include "graph/graph.h" +#include "graph/ascend_string.h" namespace ge { graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph); graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, - const std::map &parser_params, + const std::map &parser_params, ge::Graph &graph); } // namespace ge diff --git a/inc/external/parser/tensorflow_parser.h b/inc/external/parser/tensorflow_parser.h index b8c0ead..9f03097 100644 --- a/inc/external/parser/tensorflow_parser.h +++ b/inc/external/parser/tensorflow_parser.h @@ -25,11 +25,12 @@ #include "graph/ge_error_codes.h" #include "graph/types.h" #include "graph/graph.h" +#include "graph/ascend_string.h" namespace ge { graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); graphStatus aclgrphParseTensorFlow(const char *model_file, - const std::map &parser_params, + const std::map &parser_params, ge::Graph &graph); } // namespace ge diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 7bfae75..9dcd0b7 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -108,7 +108,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, return ret; } GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); - std::map parser_params; + std::map parser_params; if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); return ge::FAILED; @@ -116,8 +116,9 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, return ge::SUCCESS; } -graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, - const std::map &parser_params, +graphStatus aclgrphParseCaffe(const char *model_file, + const char *weights_file, + const std::map &parser_params, ge::Graph &graph) { GE_CHECK_NOTNULL(model_file); GetParserContext().type = domi::CAFFE; diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index c11cc37..5a776b3 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -138,11 +138,18 @@ static void GetOpsProtoPath(string &opsproto_path) { opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } -static void GetAclParams(const std::map &parser_params, +static void GetAclParams(const std::map &parser_params, const string &key, string &value) { - auto iter = parser_params.find(key); + ge::AscendString tmp_key(key.c_str()); + auto iter = parser_params.find(tmp_key); if (iter != parser_params.end()) { - value = iter->second; + ge::AscendString tmp_value = iter->second; + const char* value_ascend = tmp_value.GetString(); + if (value_ascend == nullptr) { + value = ""; + } else { + value = value_ascend; + } } else { value = ""; } @@ -883,7 +890,7 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr } domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, - const std::map &parser_params) { + const std::map &parser_params) { ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); @@ -976,14 +983,22 @@ domi::Status AclGrphParseUtil::ParseAclLogLevel(const std::string &log) { return domi::SUCCESS; } -domi::Status AclGrphParseUtil::CheckOptions(const std::map &parser_params) { +domi::Status AclGrphParseUtil::CheckOptions(const std::map &parser_params) { for (auto &ele : parser_params) { - auto it = ge::ir_option::ir_parser_suppported_options.find(ele.first); + const char * key_ascend = ele.first.GetString(); + if (key_ascend == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"parser_params", "null AscendString"}); + GELOGE(PARAM_INVALID, "input options key is null, Please check!"); + return PARAM_INVALID; + } + + string key_str = key_ascend; + auto it = ge::ir_option::ir_parser_suppported_options.find(key_str); if (it == ge::ir_option::ir_parser_suppported_options.end()) { ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, - {"parser_params", ele.first}); - GELOGE(PARAM_INVALID, "input options include unsupported option(%s).Please check!", - ele.first.c_str()); + {"parser_params", key_str}); + GELOGE(PARAM_INVALID, "input options include unsupported option(%s).Please check!", key_ascend); return PARAM_INVALID; } } @@ -1071,7 +1086,7 @@ domi::Status AclGrphParseUtil::CheckAclOpNameMap(const ComputeGraphPtr &graph, c return SUCCESS; } -domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map &parser_params, +domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map &parser_params, string &graph_name) { GELOGI("Parse graph user options start."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, @@ -1127,7 +1142,7 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map &parser_params) { + const std::map &parser_params) { // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf, ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); diff --git a/parser/common/acl_graph_parser_util.h b/parser/common/acl_graph_parser_util.h index 1d5d3ff..edbfb36 100644 --- a/parser/common/acl_graph_parser_util.h +++ b/parser/common/acl_graph_parser_util.h @@ -27,6 +27,7 @@ #include "framework/omg/parser/parser_types.h" #include "register/register_error_codes.h" #include "graph/utils/graph_utils.h" +#include "graph/ascend_string.h" namespace ge { @@ -40,16 +41,16 @@ class AclGrphParseUtil { void SaveCustomCaffeProtoPath(); domi::Status AclParserInitialize(const std::map &options); domi::Status SetOutputNodeInfo(ge::Graph &graph, - const std::map &parser_params); - domi::Status ParseParamsBeforeGraph(const std::map &parser_params, + const std::map &parser_params); + domi::Status ParseParamsBeforeGraph(const std::map &parser_params, std::string &graph_name); domi::Status ParseParamsAfterGraph(ge::Graph &graph, - const std::map &parser_params); + const std::map &parser_params); domi::Status ParseOutputInfo(ge::Graph &graph, - const std::map &parser_params); + const std::map &parser_params); private: bool parser_initialized = false; - domi::Status CheckOptions(const std::map &parser_params); + domi::Status CheckOptions(const std::map &parser_params); domi::Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info); void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, std::vector &output_nodes_name); diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 1a29652..0a70f38 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -113,7 +113,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { return ge::FAILED; } - std::map parser_params; + std::map parser_params; if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); return ge::FAILED; @@ -123,7 +123,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { } graphStatus aclgrphParseTensorFlow(const char *model_file, - const std::map &parser_params, + const std::map &parser_params, ge::Graph &graph) { GE_CHECK_NOTNULL(model_file); GetParserContext().type = domi::TENSORFLOW;