Browse Source

Feature:Support user options of aclgrphParse interface

pull/63/head
l00444296 5 years ago
parent
commit
27ddf9e8e2
1 changed files with 64 additions and 55 deletions
  1. +64
    -55
      parser/common/acl_graph_parser_util.cc

+ 64
- 55
parser/common/acl_graph_parser_util.cc View File

@@ -140,25 +140,34 @@ static void GetOpsProtoPath(string &opsproto_path) {

static void GetAclParams(const std::map<ge::AscendString, ge::AscendString> &parser_params,
const string &key, string &value) {
ge::AscendString tmp_key(key.c_str());
auto iter = parser_params.find(tmp_key);
if (iter != parser_params.end()) {
ge::AscendString tmp_value = iter->second;
const char* value_ascend = tmp_value.GetString();
if (value_ascend == nullptr) {
value = "";
} else {
value = value_ascend;
for (auto &ele : parser_params) {
const char *key_ascend = ele.first.GetString();
if (key_ascend == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"parser_params", "null AscendString"});
GELOGE(PARAM_INVALID, "Input options key is null, Please check!");
return PARAM_INVALID;
}

string key_str = key_ascend;
if (key == key_str) {
const char *value_ascend = ele.second.GetString();
if (value_ascend == nullptr) {
value = "";
} else {
value = value_ascend;
}
return;
}
} else {
value = "";
}
value = "";
return;
}

static bool CheckDigitStr(std::string &str) {
for (char c : str) {
if (!isdigit(c)) {
GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str());
GELOGE(domi::FAILED, "Value[%s] is not positive integer", str.c_str());
return false;
}
}
@@ -217,7 +226,7 @@ static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_p
return true;
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s});
GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str());
GELOGE(PARAM_INVALID, "Input parameter[%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str());
return false;
}
}
@@ -226,8 +235,8 @@ static domi::Status CheckOutPutDataTypeSupport(const std::string &output_type) {
auto it = kOutputTypeSupportDatatype.find(output_type);
if (it == kOutputTypeSupportDatatype.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", output_type, kOutputTypeSupport});
GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport);
{"output_type", output_type, kOutputTypeSupport});
GELOGE(PARAM_INVALID, "Invalid value for output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport);
return domi::PARAM_INVALID;
}
return domi::SUCCESS;
@@ -238,17 +247,17 @@ static domi::Status StringToInt(std::string &str, int32_t &value) {
if (!CheckDigitStr(str)) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", str, "is not positive integer"});
{"output_type", str, "is not positive integer"});
return PARAM_INVALID;
}
value = stoi(str);
} catch (std::invalid_argument &) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"--output_type", str});
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str});
return PARAM_INVALID;
} catch (std::out_of_range &) {
GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"--output_type", str});
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str});
return PARAM_INVALID;
}
return SUCCESS;
@@ -281,8 +290,8 @@ domi::Status VerifyOutputTypeAndOutNodes(std::vector<std::string> &out_type_vec)
for (uint32_t i = 0; i < out_type_vec.size(); ++i) {
if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", out_type_vec[i], kOutputTypeError});
GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError);
{"output_type", out_type_vec[i], kOutputTypeError});
GELOGE(domi::FAILED, "Invalid value for output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError);
return domi::FAILED;
}
}
@@ -380,9 +389,9 @@ bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) {
}
// only support NCHW ND
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, kCaffeFormatSupport});
"E10001", {"parameter", "value", "reason"}, {"input_format", input_format, kCaffeFormatSupport});
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport);
"Invalid value for input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport);
return false;
} else if (ge::GetParserContext().type == domi::TENSORFLOW) { // tf
if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) {
@@ -390,9 +399,9 @@ bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) {
}
// only support NCHW NHWC ND NCDHW NDHWC
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"}, {"--input_format", input_format, kTFFormatSupport});
"E10001", {"parameter", "value", "reason"}, {"input_format", input_format, kTFFormatSupport});
GELOGE(ge::FAILED,
"Invalid value for --input_format[%s], %s.", input_format.c_str(), kTFFormatSupport);
"Invalid value for input_format[%s], %s.", input_format.c_str(), kTFFormatSupport);
return false;
}
return true;
@@ -428,14 +437,14 @@ bool AclGrphParseUtil::ParseInputShape(const string &input_shape,
if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kSplitError1, kInputShapeSample1});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kSplitError1, kInputShapeSample1);
return false;
}
if (shape_pair_vec[1].empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kEmptyError, kInputShapeSample1});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kEmptyError, kInputShapeSample1);
return false;
}
@@ -447,7 +456,7 @@ bool AclGrphParseUtil::ParseInputShape(const string &input_shape,
if (std::string::npos != shape_value_str.find('.')) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kFloatNumError, kInputShapeSample2});
GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
shape.c_str(), kFloatNumError, kInputShapeSample2);
return false;
}
@@ -463,24 +472,24 @@ bool AclGrphParseUtil::ParseInputShape(const string &input_shape,
if (!isdigit(c)) {
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
{shape, kDigitError, kInputShapeSample2});
GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str());
GELOGE(PARAM_INVALID, "input_shape's shape value[%s] is not digit", shape_value_str.c_str());
return false;
}
}
} catch (const std::out_of_range &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str());
{"input_shape", shape_value_str});
GELOGW("Input parameter[input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str());
return false;
} catch (const std::invalid_argument &) {
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str());
{"input_shape", shape_value_str});
GELOGW("Input parameter[input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str());
return false;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"},
{"--input_shape", shape_value_str});
GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str());
{"input_shape", shape_value_str});
GELOGW("Input parameter[input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str());
return false;
}
int64_t result = left_result;
@@ -488,7 +497,7 @@ bool AclGrphParseUtil::ParseInputShape(const string &input_shape,
if (!is_dynamic_input && result <= 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)});
GELOGW(
"Input parameter[--input_shape]’s shape value[%s] is invalid, "
"Input parameter[input_shape]’s shape value[%s] is invalid, "
"expect positive integer, but value is %ld.",
shape.c_str(), result);
return false;
@@ -540,16 +549,16 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
}
ErrorManager::GetInstance().ATCReportErrMessage(
"E10001", {"parameter", "value", "reason"},
{"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""});
{"out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""});
GELOGE(PARAM_INVALID,
"The input format of --out_nodes is invalid, the correct format is "
"The input format of out_nodes is invalid, the correct format is "
"\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.",
node.c_str());
return PARAM_INVALID;
}
if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--out_nodes", out_nodes, "is not all index or top_name"});
{"out_nodes", out_nodes, "is not all index or top_name"});
GELOGE(PARAM_INVALID,
"This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str());
return PARAM_INVALID;
@@ -557,7 +566,7 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
// stoi: The method may throw an exception: invalid_argument/out_of_range
if (!CheckDigitStr(key_value_v[1])) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--out_nodes", out_nodes, "is not positive integer"});
{"out_nodes", out_nodes, "is not positive integer"});
GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str());
return PARAM_INVALID;
}
@@ -577,11 +586,11 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) {
}
} catch (std::invalid_argument &) {
GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"--out_nodes", out_nodes});
ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes});
return PARAM_INVALID;
} catch (std::out_of_range &) {
GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"--out_nodes", out_nodes});
ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes});
return PARAM_INVALID;
}
return SUCCESS;
@@ -695,7 +704,7 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra
if (node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"input_fp16_nodes", input_fp16_nodes_vec[i]});
GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not exist in model",
GELOGE(PARAM_INVALID, "Input parameter[input_fp16_nodes]'s opname[%s] is not exist in model",
input_fp16_nodes_vec[i].c_str());
return PARAM_INVALID;
}
@@ -704,7 +713,7 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra
if (op_desc->GetType() != ge::parser::DATA) {
ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"},
{"input_fp16_nodes", input_fp16_nodes_vec[i]});
GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not a input opname",
GELOGE(PARAM_INVALID, "Input parameter[input_fp16_nodes]'s opname[%s] is not a input opname",
input_fp16_nodes_vec[i].c_str());
return PARAM_INVALID;
}
@@ -747,13 +756,13 @@ domi::Status AclGrphParseUtil::ParseAclWeightCompressConf(const ComputeGraphPtr
for (size_t i = 0; i < compress_node_vec.size(); ++i) {
ge::NodePtr node = graph->FindNode(compress_node_vec[i]);
if (node == nullptr) {
GELOGW("node %s is not in graph", compress_node_vec[i].c_str());
GELOGW("Node %s is not in graph", compress_node_vec[i].c_str());
continue;
}
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) {
GELOGE(domi::FAILED, "node %s SetBool failed.", compress_node_vec[i].c_str());
GELOGE(domi::FAILED, "Node %s SetBool failed.", compress_node_vec[i].c_str());
return domi::FAILED;
}
}
@@ -772,8 +781,8 @@ domi::Status AclGrphParseUtil::ParseAclOutputType(const std::string &output_type
vector<string> node_index_type_v = StringUtils::Split(node, ':');
if (node_index_type_v.size() != 3) { // The size must be 3.
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", node, kOutputTypeSample});
GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample);
{"output_type", node, kOutputTypeSample});
GELOGE(PARAM_INVALID, "Invalid value for output_type[%s], %s.", node.c_str(), kOutputTypeSample);
return domi::FAILED;
}
ge::DataType tmp_dt;
@@ -788,8 +797,8 @@ domi::Status AclGrphParseUtil::ParseAclOutputType(const std::string &output_type
auto it = kOutputTypeSupportDatatype.find(dt_value);
if (it == kOutputTypeSupportDatatype.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--output_type", dt_value, kOutputTypeSupport});
GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport);
{"output_type", dt_value, kOutputTypeSupport});
GELOGE(ge::PARAM_INVALID, "Invalid value for output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport);
return domi::FAILED;
} else {
tmp_dt = it->second;
@@ -883,7 +892,7 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr
for (ge::NodePtr node : compute_graph->GetDirectNode()) {
if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) {
Status ret = GetOutputLeaf(node, output_nodes_info);
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail.");
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Find leaf fail.");
}
}
return domi::SUCCESS;
@@ -974,7 +983,7 @@ domi::Status AclGrphParseUtil::ParseAclLogLevel(const std::string &log) {
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"log", log});
GELOGE(PARAM_INVALID, "invalid value for log:%s, only support debug, info, warning, error, null", log.c_str());
GELOGE(PARAM_INVALID, "Invalid value for log:%s, only support debug, info, warning, error, null", log.c_str());
return PARAM_INVALID;
}
if (ret != 0) {
@@ -989,7 +998,7 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendS
if (key_ascend == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"parser_params", "null AscendString"});
GELOGE(PARAM_INVALID, "input options key is null, Please check!");
GELOGE(PARAM_INVALID, "Input options key is null, Please check!");
return PARAM_INVALID;
}

@@ -998,7 +1007,7 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::map<AscendString, AscendS
if (it == ge::ir_option::ir_parser_suppported_options.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"parser_params", key_str});
GELOGE(PARAM_INVALID, "input options include unsupported option(%s).Please check!", key_ascend);
GELOGE(PARAM_INVALID, "Input options include unsupported option(%s).Please check!", key_ascend);
return PARAM_INVALID;
}
}
@@ -1039,13 +1048,13 @@ domi::Status AclGrphParseUtil::CheckAclInputShapeNode(const ComputeGraphPtr &gra
if (node == nullptr)
{
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name});
GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not exist in model", node_name.c_str());
GELOGE(PARAM_INVALID, "Input parameter[input_shape]'s opname[%s] is not exist in model", node_name.c_str());
return PARAM_INVALID;
}
if (node->GetType() != ge::parser::DATA)
{
ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name});
GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not a input opname", node_name.c_str());
GELOGE(PARAM_INVALID, "Input parameter[input_shape]'s opname[%s] is not a input opname", node_name.c_str());
return PARAM_INVALID;
}
}


Loading…
Cancel
Save