Browse Source

Feature:Support user options of aclgrphParse interface

pull/63/head
l00444296 5 years ago
parent
commit
44762b6ea5
6 changed files with 42 additions and 23 deletions
  1. +2
    -1
      inc/external/parser/caffe_parser.h
  2. +2
    -1
      inc/external/parser/tensorflow_parser.h
  3. +4
    -3
      parser/caffe/caffe_parser.cc
  4. +26
    -11
      parser/common/acl_graph_parser_util.cc
  5. +6
    -5
      parser/common/acl_graph_parser_util.h
  6. +2
    -2
      parser/tensorflow/tensorflow_parser.cc

+ 2
- 1
inc/external/parser/caffe_parser.h View File

@@ -24,12 +24,13 @@
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/graph.h"
#include "graph/ascend_string.h"

namespace ge {
graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph);

graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
const std::map<std::string, std::string> &parser_params,
const std::map<ge::AscendString, ge::AscendString> &parser_params,
ge::Graph &graph);
} // namespace ge



+ 2
- 1
inc/external/parser/tensorflow_parser.h View File

@@ -25,11 +25,12 @@
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/graph.h"
#include "graph/ascend_string.h"

namespace ge {
graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph);
graphStatus aclgrphParseTensorFlow(const char *model_file,
const std::map<std::string, std::string> &parser_params,
const std::map<ge::AscendString, ge::AscendString> &parser_params,
ge::Graph &graph);
} // namespace ge


+ 4
- 3
parser/caffe/caffe_parser.cc View File

@@ -108,7 +108,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
return ret;
}
GELOGI("Weights parse success. graph: %s", graph.GetName().c_str());
std::map<string, string> parser_params;
std::map<AscendString, AscendString> parser_params;
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
@@ -116,8 +116,9 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
return ge::SUCCESS;
}

graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
const std::map<std::string, std::string> &parser_params,
graphStatus aclgrphParseCaffe(const char *model_file,
const char *weights_file,
const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::CAFFE;


+ 26
- 11
parser/common/acl_graph_parser_util.cc View File

@@ -138,11 +138,18 @@ static void GetOpsProtoPath(string &opsproto_path) {
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
}

static void GetAclParams(const std::map<string, string> &parser_params,
static void GetAclParams(const std::map<ge::AscendString, ge::AscendString> &parser_params,
const string &key, string &value) {
auto iter = parser_params.find(key);
ge::AscendString tmp_key(key.c_str());
auto iter = parser_params.find(tmp_key);
if (iter != parser_params.end()) {
value = iter->second;
ge::AscendString tmp_value = iter->second;
const char* value_ascend = tmp_value.GetString();
if (value_ascend == nullptr) {
value = "";
} else {
value = value_ascend;
}
} else {
value = "";
}
@@ -883,7 +890,7 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr
}

domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph,
const std::map<std::string, std::string> &parser_params) {
const std::map<AscendString, AscendString> &parser_params) {
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);

@@ -976,14 +983,22 @@ domi::Status AclGrphParseUtil::ParseAclLogLevel(const std::string &log) {
return domi::SUCCESS;
}

domi::Status AclGrphParseUtil::CheckOptions(const std::map<std::string, std::string> &parser_params) {
domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendString> &parser_params) {
for (auto &ele : parser_params) {
auto it = ge::ir_option::ir_parser_suppported_options.find(ele.first);
const char * key_ascend = ele.first.GetString();
if (key_ascend == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"parser_params", "null AscendString"});
GELOGE(PARAM_INVALID, "input options key is null, Please check!");
return PARAM_INVALID;
}

string key_str = key_ascend;
auto it = ge::ir_option::ir_parser_suppported_options.find(key_str);
if (it == ge::ir_option::ir_parser_suppported_options.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"parser_params", ele.first});
GELOGE(PARAM_INVALID, "input options include unsupported option(%s).Please check!",
ele.first.c_str());
{"parser_params", key_str});
GELOGE(PARAM_INVALID, "input options include unsupported option(%s).Please check!", key_ascend);
return PARAM_INVALID;
}
}
@@ -1071,7 +1086,7 @@ domi::Status AclGrphParseUtil::CheckAclOpNameMap(const ComputeGraphPtr &graph, c
return SUCCESS;
}

domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<std::string, std::string> &parser_params,
domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
string &graph_name) {
GELOGI("Parse graph user options start.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS,
@@ -1127,7 +1142,7 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map<std::string
}

domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph,
const std::map<std::string, std::string> &parser_params) {
const std::map<AscendString, AscendString> &parser_params) {
// support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf,
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);



+ 6
- 5
parser/common/acl_graph_parser_util.h View File

@@ -27,6 +27,7 @@
#include "framework/omg/parser/parser_types.h"
#include "register/register_error_codes.h"
#include "graph/utils/graph_utils.h"
#include "graph/ascend_string.h"

namespace ge {

@@ -40,16 +41,16 @@ class AclGrphParseUtil {
void SaveCustomCaffeProtoPath();
domi::Status AclParserInitialize(const std::map<std::string, std::string> &options);
domi::Status SetOutputNodeInfo(ge::Graph &graph,
const std::map<std::string, std::string> &parser_params);
domi::Status ParseParamsBeforeGraph(const std::map<std::string, std::string> &parser_params,
const std::map<AscendString, AscendString> &parser_params);
domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
std::string &graph_name);
domi::Status ParseParamsAfterGraph(ge::Graph &graph,
const std::map<std::string, std::string> &parser_params);
const std::map<AscendString, AscendString> &parser_params);
domi::Status ParseOutputInfo(ge::Graph &graph,
const std::map<std::string, std::string> &parser_params);
const std::map<AscendString, AscendString> &parser_params);
private:
bool parser_initialized = false;
domi::Status CheckOptions(const std::map<std::string, std::string> &parser_params);
domi::Status CheckOptions(const std::map<AscendString, AscendString> &parser_params);
domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);


+ 2
- 2
parser/tensorflow/tensorflow_parser.cc View File

@@ -113,7 +113,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {
return ge::FAILED;
}

std::map<string, string> parser_params;
std::map<AscendString, AscendString> parser_params;
if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) {
GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
@@ -123,7 +123,7 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) {
}

graphStatus aclgrphParseTensorFlow(const char *model_file,
const std::map<std::string, std::string> &parser_params,
const std::map<AscendString, AscendString> &parser_params,
ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::TENSORFLOW;


Loading…
Cancel
Save