| @@ -24,12 +24,13 @@ | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/types.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/ascend_string.h" | |||
| namespace ge { | |||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph); | |||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||
| const std::map<std::string, std::string> &parser_params, | |||
| const std::map<ge::AscendString, ge::AscendString> &parser_params, | |||
| ge::Graph &graph); | |||
| } // namespace ge | |||
| @@ -25,11 +25,12 @@ | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/types.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/ascend_string.h" | |||
| namespace ge { | |||
| graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); | |||
| graphStatus aclgrphParseTensorFlow(const char *model_file, | |||
| const std::map<std::string, std::string> &parser_params, | |||
| const std::map<ge::AscendString, ge::AscendString> &parser_params, | |||
| ge::Graph &graph); | |||
| } // namespace ge | |||
| @@ -108,7 +108,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||
| return ret; | |||
| } | |||
| GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | |||
| std::map<string, string> parser_params; | |||
| std::map<AscendString, AscendString> parser_params; | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| @@ -116,8 +116,9 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||
| return ge::SUCCESS; | |||
| } | |||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||
| const std::map<std::string, std::string> &parser_params, | |||
| graphStatus aclgrphParseCaffe(const char *model_file, | |||
| const char *weights_file, | |||
| const std::map<AscendString, AscendString> &parser_params, | |||
| ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(model_file); | |||
| GetParserContext().type = domi::CAFFE; | |||
| @@ -138,11 +138,18 @@ static void GetOpsProtoPath(string &opsproto_path) { | |||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||
| } | |||
| static void GetAclParams(const std::map<string, string> &parser_params, | |||
| static void GetAclParams(const std::map<ge::AscendString, ge::AscendString> &parser_params, | |||
| const string &key, string &value) { | |||
| auto iter = parser_params.find(key); | |||
| ge::AscendString tmp_key(key.c_str()); | |||
| auto iter = parser_params.find(tmp_key); | |||
| if (iter != parser_params.end()) { | |||
| value = iter->second; | |||
| ge::AscendString tmp_value = iter->second; | |||
| const char* value_ascend = tmp_value.GetString(); | |||
| if (value_ascend == nullptr) { | |||
| value = ""; | |||
| } else { | |||
| value = value_ascend; | |||
| } | |||
| } else { | |||
| value = ""; | |||
| } | |||
| @@ -883,7 +890,7 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr | |||
| } | |||
| domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, | |||
| const std::map<std::string, std::string> &parser_params) { | |||
| const std::map<AscendString, AscendString> &parser_params) { | |||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| @@ -976,14 +983,22 @@ domi::Status AclGrphParseUtil::ParseAclLogLevel(const std::string &log) { | |||
| return domi::SUCCESS; | |||
| } | |||
| domi::Status AclGrphParseUtil::CheckOptions(const std::map<std::string, std::string> &parser_params) { | |||
| domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) { | |||
| for (auto &ele : parser_params) { | |||
| auto it = ge::ir_option::ir_parser_suppported_options.find(ele.first); | |||
| const char * key_ascend = ele.first.GetString(); | |||
| if (key_ascend == nullptr) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||
| {"parser_params", "null AscendString"}); | |||
| GELOGE(PARAM_INVALID, "input options key is null, Please check!"); | |||
| return PARAM_INVALID; | |||
| } | |||
| string key_str = key_ascend; | |||
| auto it = ge::ir_option::ir_parser_suppported_options.find(key_str); | |||
| if (it == ge::ir_option::ir_parser_suppported_options.end()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||
| {"parser_params", ele.first}); | |||
| GELOGE(PARAM_INVALID, "input options include unsupported option(%s).Please check!", | |||
| ele.first.c_str()); | |||
| {"parser_params", key_str}); | |||
| GELOGE(PARAM_INVALID, "input options include unsupported option(%s).Please check!", key_ascend); | |||
| return PARAM_INVALID; | |||
| } | |||
| } | |||
| @@ -1071,7 +1086,7 @@ domi::Status AclGrphParseUtil::CheckAclOpNameMap(const ComputeGraphPtr &graph, c | |||
| return SUCCESS; | |||
| } | |||
| domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<std::string, std::string> &parser_params, | |||
| domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | |||
| string &graph_name) { | |||
| GELOGI("Parse graph user options start."); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, | |||
| @@ -1127,7 +1142,7 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<std::string | |||
| } | |||
| domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, | |||
| const std::map<std::string, std::string> &parser_params) { | |||
| const std::map<AscendString, AscendString> &parser_params) { | |||
| // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf, | |||
| ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | |||
| @@ -27,6 +27,7 @@ | |||
| #include "framework/omg/parser/parser_types.h" | |||
| #include "register/register_error_codes.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/ascend_string.h" | |||
| namespace ge { | |||
| @@ -40,16 +41,16 @@ class AclGrphParseUtil { | |||
| void SaveCustomCaffeProtoPath(); | |||
| domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | |||
| domi::Status SetOutputNodeInfo(ge::Graph &graph, | |||
| const std::map<std::string, std::string> &parser_params); | |||
| domi::Status ParseParamsBeforeGraph(const std::map<std::string, std::string> &parser_params, | |||
| const std::map<AscendString, AscendString> &parser_params); | |||
| domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | |||
| std::string &graph_name); | |||
| domi::Status ParseParamsAfterGraph(ge::Graph &graph, | |||
| const std::map<std::string, std::string> &parser_params); | |||
| const std::map<AscendString, AscendString> &parser_params); | |||
| domi::Status ParseOutputInfo(ge::Graph &graph, | |||
| const std::map<std::string, std::string> &parser_params); | |||
| const std::map<AscendString, AscendString> &parser_params); | |||
| private: | |||
| bool parser_initialized = false; | |||
| domi::Status CheckOptions(const std::map<std::string, std::string> &parser_params); | |||
| domi::Status CheckOptions(const std::map<AscendString, AscendString> &parser_params); | |||
| domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name); | |||
| @@ -113,7 +113,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { | |||
| return ge::FAILED; | |||
| } | |||
| std::map<string, string> parser_params; | |||
| std::map<AscendString, AscendString> parser_params; | |||
| if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { | |||
| GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||
| return ge::FAILED; | |||
| @@ -123,7 +123,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { | |||
| } | |||
| graphStatus aclgrphParseTensorFlow(const char *model_file, | |||
| const std::map<std::string, std::string> &parser_params, | |||
| const std::map<AscendString, AscendString> &parser_params, | |||
| ge::Graph &graph) { | |||
| GE_CHECK_NOTNULL(model_file); | |||
| GetParserContext().type = domi::TENSORFLOW; | |||