From 9e2dd62f764ae4f92fb6c586d0c6c9151a0791bd Mon Sep 17 00:00:00 2001 From: l00444296 Date: Sat, 14 Nov 2020 18:12:45 +0800 Subject: [PATCH] Feature:Support user options of aclgrphParse interface --- inc/external/parser/caffe_parser.h | 6 +- inc/external/parser/tensorflow_parser.h | 5 +- metadef | 2 +- parser/caffe/caffe_parser.cc | 57 +- parser/common/acl_graph_parser_util.cc | 1046 +++++++++++++++++++++-- parser/common/acl_graph_parser_util.h | 38 +- parser/tensorflow/tensorflow_parser.cc | 48 +- 7 files changed, 1130 insertions(+), 72 deletions(-) diff --git a/inc/external/parser/caffe_parser.h b/inc/external/parser/caffe_parser.h index 2a687d0..f96dbfd 100644 --- a/inc/external/parser/caffe_parser.h +++ b/inc/external/parser/caffe_parser.h @@ -21,12 +21,16 @@ #include #include +#include "graph/ascend_string.h" #include "graph/ge_error_codes.h" -#include "graph/types.h" #include "graph/graph.h" +#include "graph/types.h" namespace ge { graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph); + +graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, + const std::map &parser_params, ge::Graph &graph); } // namespace ge #endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ diff --git a/inc/external/parser/tensorflow_parser.h b/inc/external/parser/tensorflow_parser.h index b7c1c8c..3ff4773 100644 --- a/inc/external/parser/tensorflow_parser.h +++ b/inc/external/parser/tensorflow_parser.h @@ -22,12 +22,15 @@ #include #include +#include "graph/ascend_string.h" #include "graph/ge_error_codes.h" -#include "graph/types.h" #include "graph/graph.h" +#include "graph/types.h" namespace ge { graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); +graphStatus aclgrphParseTensorFlow(const char *model_file, + const std::map &parser_params, ge::Graph &graph); } // namespace ge #endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ \ No newline at end of file diff --git a/metadef b/metadef index 9bbce07..37465b8 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 9bbce07b846858fa30ef2bd7c662894e20f83ef1 +Subproject commit 37465b85d30b67a0edcc6ea4acd2f11a9697c7af diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 4088a4c..5c48a0b 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -108,12 +108,67 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, return ret; } GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); - if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { + std::map parser_params; + if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); return ge::FAILED; } return ge::SUCCESS; } + +graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, + const std::map &parser_params, ge::Graph &graph) { + GE_CHECK_NOTNULL(model_file); + GetParserContext().type = domi::CAFFE; + std::map options; + options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE))); + + // load custom plugin so and proto + AclGrphParseUtil acl_graph_parse_util; + (void)acl_graph_parse_util.AclParserInitialize(options); + + string output_name; + if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params before graph failed."); + return ge::FAILED; + } + // Create an empty computegraph + string graph_name = output_name.empty() ? "tmpGraph" : output_name; + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared(graph_name); + GE_CHECK_NOTNULL(compute_graph); + + graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); + GE_CHECK_NOTNULL(model_parser); + + // parse caffe model_file and weights_file to GE graph + ge::graphStatus ret = model_parser->Parse(model_file, graph); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("Parser graph %s success.", graph.GetName().c_str()); + + if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params after graph failed."); + return ge::FAILED; + } + + auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE); + ret = weights_parser->Parse(weights_file, graph); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str()); + return ret; + } + GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); + + if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); + return ge::SUCCESS; +} } // namespace ge diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index 0a16d38..8a2f25c 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -17,24 +17,29 @@ #include "parser/common/acl_graph_parser_util.h" #include -#include -#include #include + +#include #include +#include -#include "common/string_util.h" #include "common/debug/log.h" #include "common/op/ge_op_utils.h" -#include "ge/ge_api_types.h" -#include "graph/opsproto_manager.h" -#include "omg/parser/parser_inner_ctx.h" -#include "tbe_plugin_loader.h" +#include "common/string_util.h" +#include "common/types.h" +#include "common/util.h" +#include "common/util/error_manager/error_manager.h" +#include "external/ge/ge_api_types.h" #include "framework/common/debug/ge_log.h" -#include "parser/common/register_tbe.h" #include "framework/omg/parser/parser_types.h" -#include "common/util/error_manager/error_manager.h" +#include "ge/ge_api_types.h" #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/zero_copy_stream_impl.h" +#include "graph/opsproto_manager.h" +#include "graph/utils/type_utils.h" +#include "omg/parser/parser_inner_ctx.h" +#include "parser/common/register_tbe.h" +#include "tbe_plugin_loader.h" using google::protobuf::io::CodedInputStream; using google::protobuf::io::FileInputStream; @@ -42,12 +47,47 @@ using google::protobuf::io::ZeroCopyInputStream; using namespace ge::parser; namespace { +static std::map kInputFormatStrToGeformat = { + {"ND", domi::DOMI_TENSOR_ND}, {"NCHW", domi::DOMI_TENSOR_NCHW}, {"NHWC", domi::DOMI_TENSOR_NHWC}, + {"CHWN", domi::DOMI_TENSOR_CHWN}, {"NC1HWC0", domi::DOMI_TENSOR_NC1HWC0}, {"NHWC1C0", domi::DOMI_TENSOR_NHWC1C0}, + {"NCDHW", domi::DOMI_TENSOR_NCDHW}, {"NDHWC", domi::DOMI_TENSOR_NDHWC}}; + +// datatype/formats from user to GE, Unified to util interface file later +const std::map kOutputTypeSupportDatatype = { + {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; +const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; +const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; +const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; +const char *const kSplitError1 = "size not equal to 2 split by \":\""; +const char *const kEmptyError = "can not be empty"; +const char *const kFloatNumError = "exist float number"; +const char *const kDigitError = "is not digit"; +const std::string kGraphDefaultName = "domi_default"; +const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; +const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; +static std::set kCaffeSupportInputFormatSet = {"NCHW", "ND"}; +static std::set kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; +const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; +const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; /// The maximum length of the file. /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 const int kMaxFileSizeLimit = INT_MAX; const int kMaxBuffSize = 256; -const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. -const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M +const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. +const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M +const int kOutputTypeNode = 0; +const int kOutputTypeIndex = 1; +const int kOutputTypeDataType = 2; + +vector SplitInputShape(const std::string &input_shape) { + vector shape_pair_vec; + size_t pos = input_shape.rfind(":"); + if (pos != std::string::npos) { + shape_pair_vec.emplace_back(input_shape.substr(0, pos)); + shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); + } + return shape_pair_vec; +} static string GetSoPath() { Dl_info dl_info; @@ -92,78 +132,165 @@ static void GetOpsProtoPath(string &opsproto_path) { path_base = path_base.substr(0, path_base.rfind('/') + 1); opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } -} // namespace -namespace ge { -domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, - std::vector> &output_nodes_info) { - ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); - if (tmpDescPtr == nullptr) { - GELOGE(domi::FAILED, "Get outnode op desc fail."); - return domi::FAILED; - } - size_t size = tmpDescPtr->GetOutputsSize(); - if (node->GetType() != NETOUTPUT) { - for (size_t index = 0; index < size; ++index) { - output_nodes_info.push_back(std::make_pair(node, index)); +static void GetAclParams(const std::map &parser_params, const string &key, + string &value) { + for (auto &ele : parser_params) { + const char *key_ascend = ele.first.GetString(); + if (key_ascend == nullptr) { + GELOGW("Input options key is null, Please check!"); + continue; } - } else { - const auto in_anchors = node->GetAllInDataAnchors(); - for (auto in_anchor : in_anchors) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - GELOGE(domi::FAILED, "Get leaf node op desc fail."); - return domi::FAILED; + + string key_str = key_ascend; + if (key == key_str) { + const char *value_ascend = ele.second.GetString(); + if (value_ascend == nullptr) { + value = ""; + } else { + value = value_ascend; } - auto out_node = out_anchor->GetOwnerNode(); - output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); + return; } } - return SUCCESS; + value = ""; + return; } -void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, - std::vector &output_nodes_name) { - output_nodes_name.clear(); - if (ge::GetParserContext().out_top_names.empty()) { - // tf process, no top name. - for (const auto output_node_info : output_nodes_info) { - std::string node_name = output_node_info.first->GetName(); - int32_t index = output_node_info.second; - output_nodes_name.push_back(node_name + ":" + std::to_string(index)); +static bool CheckDigitStr(std::string &str) { + for (char c : str) { + if (!isdigit(c)) { + GELOGE(domi::FAILED, "Value[%s] is not positive integer", str.c_str()); + return false; } - return; } - // caffe process reserved place; + return true; } -domi::Status AclGrphParseUtil::SetDefaultOutputNode(ge::Graph &graph) { - ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); - if (compute_graph == nullptr) { - GELOGE(FAILED, "compute_graph is nullptr."); - return FAILED; +// Remove the space and tab before and after the string +std::string TrimConf(const std::string &str) { + if (str.empty()) { + return str; } - std::vector> output_nodes_info; - std::vector output_nodes_name; + std::string::size_type start = str.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) { + return str; + } - for (ge::NodePtr node : compute_graph->GetDirectNode()) { - if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) { - Status ret = AclGrphParseUtil::GetOutputLeaf(node, output_nodes_info); - if (ret != SUCCESS) { - GELOGE(FAILED, "find leaf fail."); - return FAILED; - } + std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1; + return str.substr(start, end); +} + +// Parsing the command line +bool ParseSingleLine(const std::string &line, std::map &op_conf_map) { + std::string temp = TrimConf(line); + std::string delimiter = ":"; + // Comment or newline returns true directly + if (temp.find_first_of('#') == 0 || *(temp.c_str()) == '\n') { + return true; + } + + if (!temp.empty()) { + std::string::size_type pos = temp.find_first_of(delimiter); + if (pos == std::string::npos) { + GELOGE(ge::PARAM_INVALID, "Incorrect line [%s], it must include [%s].Perhaps you use illegal chinese symbol", + line.c_str(), delimiter.c_str()); + return false; } + + std::string map_key = TrimConf(temp.substr(0, pos)); + std::string value = TrimConf(temp.substr(pos + 1)); + if (map_key.empty() || value.empty()) { + GELOGE(ge::PARAM_INVALID, "Map_key or value empty. %s", line.c_str()); + return false; + } + + op_conf_map[map_key] = value; } + return true; +} - AclGrphParseUtil::GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); - compute_graph->SetGraphOutNodesInfo(output_nodes_info); - ge::GetParserContext().net_out_nodes = output_nodes_name; - GELOGI("Set graph %s default output node success.", graph.GetName().c_str()); +} // namespace + +namespace ge { +static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) { + if ((s == "true") || (s == "false")) { + return true; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s}); + GELOGE(PARAM_INVALID, "Input parameter[%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str()); + return false; + } +} + +static domi::Status CheckOutPutDataTypeSupport(const std::string &output_type) { + auto it = kOutputTypeSupportDatatype.find(output_type); + if (it == kOutputTypeSupportDatatype.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"output_type", output_type, kOutputTypeSupport}); + GELOGE(PARAM_INVALID, "Invalid value for output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); + return domi::PARAM_INVALID; + } + return domi::SUCCESS; +} + +static domi::Status StringToInt(std::string &str, int32_t &value) { + try { + if (!CheckDigitStr(str)) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"output_type", str, "is not positive integer"}); + return PARAM_INVALID; + } + value = stoi(str); + } catch (std::invalid_argument &) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str}); + return PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str}); + return PARAM_INVALID; + } return SUCCESS; } +static Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { + int32_t out_size = op_desc->GetOutputsSize(); + if (index < 0 || index >= out_size) { + GELOGE(domi::FAILED, + "out_node [%s] output index:%d must be smaller " + "than node output size:%d and can not be negative!", + op_desc->GetName().c_str(), index, out_size); + std::string fail_reason = "output index:" + to_string(index) + + " must be smaller than output size:" + to_string(out_size) + " and can not be negative!"; + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, + {"out_nodes", op_desc->GetName(), fail_reason}); + return domi::FAILED; + } + return domi::SUCCESS; +} + +domi::Status VerifyOutputTypeAndOutNodes(std::vector &out_type_vec) { + std::vector> user_out_nodes = ge::GetParserContext().user_out_nodes; + std::set out_nodes_info; + for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { + // out_nodes set should include output_type and output_format + std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second); + out_nodes_info.emplace(tmp); + } + for (uint32_t i = 0; i < out_type_vec.size(); ++i) { + if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"output_type", out_type_vec[i], kOutputTypeError}); + GELOGE(domi::FAILED, "Invalid value for output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); + return domi::FAILED; + } + } + return domi::SUCCESS; +} + domi::Status AclGrphParseUtil::LoadOpsProtoLib() { string opsproto_path; GetOpsProtoPath(opsproto_path); @@ -239,6 +366,799 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::mapsecond; + } else { + GELOGE(PARAM_INVALID, "Input format %s not support , expect ND/NCHW/NHWC/CHWN/NC1HWC0/NHWC1C0.", + input_format.c_str()); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +bool AclGrphParseUtil::ParseInputShape(const string &input_shape, + std::unordered_map> &shape_map, + vector>> &user_shape_map, bool is_dynamic_input) { + vector shape_vec = StringUtils::Split(input_shape, ';'); + const int DEFAULT_SHAPE_PAIR_SIZE = 2; + for (const auto &shape : shape_vec) { + vector shape_pair_vec = SplitInputShape(shape); + if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kSplitError1, kInputShapeSample1}); + GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", shape.c_str(), + kSplitError1, kInputShapeSample1); + return false; + } + if (shape_pair_vec[1].empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kEmptyError, kInputShapeSample1}); + GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", shape.c_str(), + kEmptyError, kInputShapeSample1); + return false; + } + + vector shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); + vector shape_values; + for (auto &shape_value_str : shape_value_strs) { + // stoul: The method may throw an exception: invalid_argument/out_of_range + if (std::string::npos != shape_value_str.find('.')) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kFloatNumError, kInputShapeSample2}); + GELOGW("Parse input parameter [input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kFloatNumError, kInputShapeSample2); + return false; + } + + long left_result = 0; + try { + left_result = stol(StringUtils::Trim(shape_value_str)); + if (!shape_value_str.empty() && (shape_value_str.front() == '-')) { + // The value maybe dynamic shape [-1], need substr it and verify isdigit. + shape_value_str = shape_value_str.substr(1); + } + for (char c : shape_value_str) { + if (!isdigit(c)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kDigitError, kInputShapeSample2}); + GELOGE(PARAM_INVALID, "input_shape's shape value[%s] is not digit", shape_value_str.c_str()); + return false; + } + } + } catch (const std::out_of_range &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, + {"input_shape", shape_value_str}); + GELOGW("Input parameter[input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str()); + return false; + } catch (const std::invalid_argument &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, + {"input_shape", shape_value_str}); + GELOGW("Input parameter[input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str()); + return false; + } catch (...) { + ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, + {"input_shape", shape_value_str}); + GELOGW("Input parameter[input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str()); + return false; + } + int64_t result = left_result; + // - 1 is not currently supported + if (!is_dynamic_input && result <= 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)}); + GELOGW( + "Input parameter[input_shape]’s shape value[%s] is invalid, " + "expect positive integer, but value is %ld.", + shape.c_str(), result); + return false; + } + shape_values.push_back(result); + } + + shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + } + + return true; +} + +// Parse user input shape info +domi::Status AclGrphParseUtil::ParseAclShape(const string &input_shape, bool is_dynamic_input) { + ge::GetParserContext().input_dims.clear(); + ge::GetParserContext().user_input_dims.clear(); + ge::GetParserContext().is_dynamic_input = is_dynamic_input; + + if (input_shape.empty()) { + return SUCCESS; + } + + std::unordered_map> &shape_map = ge::GetParserContext().input_dims; + if (!ParseInputShape(input_shape, ge::GetParserContext().input_dims, ge::GetParserContext().user_input_dims, + is_dynamic_input) || + shape_map.empty()) { + GELOGE(PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { + try { + // parse output node + if (!out_nodes.empty()) { + ge::GetParserContext().out_nodes_map.clear(); + ge::GetParserContext().user_out_nodes.clear(); + ge::GetParserContext().user_out_nodes_top_vec.clear(); + + vector nodes_v = StringUtils::Split(out_nodes, ';'); + for (const string &node : nodes_v) { + vector key_value_v = StringUtils::Split(node, ':'); + if (key_value_v.size() != 2) { // The size must be 2. + if (key_value_v.size() == 1 && ge::GetParserContext().type == domi::CAFFE) { + ge::GetParserContext().user_out_nodes_top_vec.push_back(node); + continue; + } + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""}); + GELOGE(PARAM_INVALID, + "The input format of out_nodes is invalid, the correct format is " + "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", + node.c_str()); + return PARAM_INVALID; + } + if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"out_nodes", out_nodes, "is not all index or top_name"}); + GELOGE(PARAM_INVALID, "This out_nodes str must be all index or top_name, while the actual input is %s", + out_nodes.c_str()); + return PARAM_INVALID; + } + // stoi: The method may throw an exception: invalid_argument/out_of_range + if (!CheckDigitStr(key_value_v[1])) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"out_nodes", out_nodes, "is not positive integer"}); + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str()); + return PARAM_INVALID; + } + + auto iter = ge::GetParserContext().out_nodes_map.find(key_value_v[0]); + int32_t index = stoi(StringUtils::Trim(key_value_v[1])); + GELOGD("Get output info: node[%s] and index[%ld]", key_value_v[0].c_str(), index); + if (iter != ge::GetParserContext().out_nodes_map.end()) { + iter->second.emplace_back(index); + } else { + std::vector index_v; + index_v.emplace_back(index); + ge::GetParserContext().out_nodes_map.emplace(key_value_v[0], index_v); + } + ge::GetParserContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); + } + } + } catch (std::invalid_argument &) { + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes}); + return PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes}); + return PARAM_INVALID; + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseAclOutputFp16NodesFormat(const string &is_output_fp16) { + if (is_output_fp16.empty()) { + return SUCCESS; + } + + vector &output_formats = ge::GetParserContext().output_formats; + output_formats.clear(); + vector node_format_vec = StringUtils::Split(is_output_fp16, ','); + for (auto &is_fp16 : node_format_vec) { + StringUtils::Trim(is_fp16); + if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) { + GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]", + is_output_fp16.c_str()); + return PARAM_INVALID; + } + if (is_fp16 == "false") { + output_formats.push_back(DOMI_TENSOR_ND); + } else if (is_fp16 == "true") { + output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0); + } + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseAclOpConf(const std::string &op_conf) { + if (op_conf.empty()) { + return SUCCESS; + } + // Normalize the path + string resolved_file_path = ge::parser::RealPath(op_conf.c_str()); + if (resolved_file_path.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"op_map_conf", op_conf}); + GELOGE(domi::FAILED, "Invalid input file path [%s], make sure that the file path is correct.", op_conf.c_str()); + return FAILED; + } + std::ifstream fs(resolved_file_path, std::ifstream::in); + + if (!fs.is_open()) { + GELOGE(PARAM_INVALID, "Open %s failed.", op_conf.c_str()); + return FAILED; + } + + std::string line; + + while (getline(fs, line)) { // line not with \n + if (!ParseSingleLine(line, ge::GetParserContext().op_conf_map)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"op_map_conf_line_info", line}); + GELOGE(PARAM_INVALID, "Parse line failed. content is [%s].", line.c_str()); + fs.close(); + return FAILED; + } + } + fs.close(); // close the file + + GELOGI("LoadFileContent success."); + return SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fusion_passes) { + ge::GetParserContext().enable_scope_fusion_passes.clear(); + if (enable_scope_fusion_passes.empty()) { + return SUCCESS; + } + ge::GetParserContext().enable_scope_fusion_passes = enable_scope_fusion_passes; + return SUCCESS; +} + +void AclGrphParseUtil::AddAttrsForInputNodes(const vector &adjust_fp16_format_vec, + const string &fp16_nodes_name, uint32_t index, OpDescPtr &op_desc) { + if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { + if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) { + GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str()); + if (!AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_FORMAT, TypeUtils::FormatToSerialString(FORMAT_NC1HWC0))) { + GELOGW("This node [%s] set NC1HWC0 failed", fp16_nodes_name.c_str()); + } + } + } +} + +domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, + const string &is_input_adjust_hw_layout) { + GE_CHECK_NOTNULL(graph); + vector adjust_fp16_format_vec; + if (!is_input_adjust_hw_layout.empty()) { + adjust_fp16_format_vec = StringUtils::Split(is_input_adjust_hw_layout, ','); + for (auto &s : adjust_fp16_format_vec) { + StringUtils::Trim(s); + if (!CheckInputTrueOrFalse(s, "is_input_adjust_hw_layout")) { + GELOGE(PARAM_INVALID, "Invalid Param, is_input_adjust_hw_layout only support true/false: but is [%s]", + is_input_adjust_hw_layout.c_str()); + return PARAM_INVALID; + } + } + } + if (input_fp16_nodes.empty()) { + return SUCCESS; + } + GELOGI("The input_fp16_nodes is set %s", input_fp16_nodes.c_str()); + vector input_fp16_nodes_vec = StringUtils::Split(input_fp16_nodes, ';'); + for (uint32_t i = 0; i < input_fp16_nodes_vec.size(); ++i) { + ge::NodePtr node = graph->FindNode(input_fp16_nodes_vec[i]); + if (node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"input_fp16_nodes", input_fp16_nodes_vec[i]}); + GELOGE(PARAM_INVALID, "Input parameter[input_fp16_nodes]'s opname[%s] is not exist in model", + input_fp16_nodes_vec[i].c_str()); + return PARAM_INVALID; + } + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() != ge::parser::DATA) { + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, + {"input_fp16_nodes", input_fp16_nodes_vec[i]}); + GELOGE(PARAM_INVALID, "Input parameter[input_fp16_nodes]'s opname[%s] is not a input opname", + input_fp16_nodes_vec[i].c_str()); + return PARAM_INVALID; + } + AddAttrsForInputNodes(adjust_fp16_format_vec, input_fp16_nodes_vec[i], i, op_desc); + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseAclWeightCompressConf(const ComputeGraphPtr &graph, + const string &compress_weight_conf) { + GE_CHECK_NOTNULL(graph); + if (compress_weight_conf.empty()) { + return SUCCESS; + } + std::string real_path = ge::parser::RealPath(compress_weight_conf.c_str()); + if (real_path.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"compress_weight_conf", compress_weight_conf}); + GELOGE(PARAM_INVALID, "Can not get real path for %s.", compress_weight_conf.c_str()); + return PARAM_INVALID; + } + std::ifstream ifs(real_path); + if (!ifs.is_open()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"compress_weight_conf", compress_weight_conf}); + GELOGE(FAILED, "Open file %s failed", compress_weight_conf.c_str()); + return FAILED; + } + + std::string compress_nodes; + ifs >> compress_nodes; + ifs.close(); + if (compress_nodes.empty()) { + GELOGW("Compress weight of nodes info is empty"); + return SUCCESS; + } + GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); + + vector compress_node_vec = StringUtils::Split(compress_nodes, ';'); + for (size_t i = 0; i < compress_node_vec.size(); ++i) { + ge::NodePtr node = graph->FindNode(compress_node_vec[i]); + if (node == nullptr) { + GELOGW("Node %s is not in graph", compress_node_vec[i].c_str()); + continue; + } + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { + GELOGE(domi::FAILED, "Node %s SetBool failed.", compress_node_vec[i].c_str()); + return domi::FAILED; + } + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseAclOutputType(const std::string &output_type, + std::map> &output_node_dt_map) { + if (output_type.find(':') == std::string::npos) { + GELOGI("output_type is not multiple nodes, means all out nodes"); + return CheckOutPutDataTypeSupport(output_type); + } + std::vector out_type_vec; + vector nodes_v = StringUtils::Split(output_type, ';'); + for (const string &node : nodes_v) { + vector node_index_type_v = StringUtils::Split(node, ':'); + if (node_index_type_v.size() != 3) { // The size must be 3. + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"output_type", node, kOutputTypeSample}); + GELOGE(PARAM_INVALID, "Invalid value for output_type[%s], %s.", node.c_str(), kOutputTypeSample); + return domi::FAILED; + } + ge::DataType tmp_dt; + std::string node_name = StringUtils::Trim(node_index_type_v[kOutputTypeNode]); + std::string index_str = StringUtils::Trim(node_index_type_v[kOutputTypeIndex]); + int32_t index; + if (StringToInt(index_str, index) != SUCCESS) { + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); + return domi::FAILED; + } + std::string dt_value = StringUtils::Trim(node_index_type_v[kOutputTypeDataType]); + auto it = kOutputTypeSupportDatatype.find(dt_value); + if (it == kOutputTypeSupportDatatype.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"output_type", dt_value, kOutputTypeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport); + return domi::FAILED; + } else { + tmp_dt = it->second; + } + out_type_vec.push_back(node_name + ":" + index_str); + std::string index_dt_str = index_str + ":" + TypeUtils::DataTypeToSerialString(tmp_dt); + auto it1 = output_node_dt_map.find(node_name); + if (it1 == output_node_dt_map.end()) { + vector tmp_vec; + tmp_vec.push_back(index_dt_str); + output_node_dt_map.emplace(node_name, tmp_vec); + } else { + it1->second.push_back(index_dt_str); + } + } + return VerifyOutputTypeAndOutNodes(out_type_vec); +} + +void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, + std::vector &output_nodes_name) { + output_nodes_name.clear(); + if (ge::GetParserContext().out_top_names.empty()) { + // tf process, no top name. + for (const auto output_node_info : output_nodes_info) { + std::string node_name = output_node_info.first->GetName(); + int32_t index = output_node_info.second; + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + return; + } + // caffe process, need add top name after node_name:index + for (size_t i = 0; i < output_nodes_info.size(); ++i) { + std::string node_name = output_nodes_info[i].first->GetName(); + int32_t index = output_nodes_info[i].second; + if (i < ge::GetParserContext().out_top_names.size()) { + output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + + ge::GetParserContext().out_top_names[i]); + } else { + GELOGW("Get top name of node [%s] fail.", node_name.c_str()); + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + } +} + +domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, + std::vector> &output_nodes_info) { + ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); + if (tmpDescPtr == nullptr) { + GELOGE(domi::FAILED, "Get outnode op desc fail."); + return domi::FAILED; + } + size_t size = tmpDescPtr->GetOutputsSize(); + if (node->GetType() != ge::parser::NETOUTPUT) { + for (size_t index = 0; index < size; ++index) { + output_nodes_info.push_back(std::make_pair(node, index)); + GELOGD("Get output leaf node:%s.", node->GetName().c_str()); + } + } else { + const auto in_anchors = node->GetAllInDataAnchors(); + for (auto in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGE(domi::FAILED, "Get leaf node op desc fail."); + return domi::FAILED; + } + auto out_node = out_anchor->GetOwnerNode(); + output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); + } + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, + std::vector> &output_nodes_info) { + std::vector> default_out_nodes = ge::GetParserContext().default_out_nodes; + if (ge::GetParserContext().type == domi::CAFFE && !default_out_nodes.empty()) { + for (uint32_t i = 0; i < default_out_nodes.size(); ++i) { + ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); + if (out_node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"out_nodes", default_out_nodes[i].first}); + GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", default_out_nodes[i].first.c_str()); + return domi::FAILED; + } + output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second)); + GELOGD("Get default output node:%s.", out_node->GetName().c_str()); + } + return domi::SUCCESS; + } + + for (ge::NodePtr node : compute_graph->GetDirectNode()) { + if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) { + Status ret = GetOutputLeaf(node, output_nodes_info); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Find leaf fail."); + } + } + return domi::SUCCESS; +} + +domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, + const std::map &parser_params) { + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + string output_type; + GetAclParams(parser_params, ge::ir_option::OUTPUT_TYPE, output_type); + + std::vector> user_out_nodes = ge::GetParserContext().user_out_nodes; + std::vector output_formats = ge::GetParserContext().output_formats; + std::vector> output_nodes_info; + std::vector output_nodes_name; + std::map> output_node_dt_map; + if (!output_type.empty()) { + if (ParseAclOutputType(output_type, output_node_dt_map) != SUCCESS) { + GELOGE(domi::FAILED, "Parse output_type failed."); + return domi::FAILED; + } + } + + // User declared outputs + for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { + ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first); + if (out_node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"out_nodes", user_out_nodes[i].first}); + GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } + auto op_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) { + GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } + + // add user_define_output_nodes attr. + (void)ge::AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_OUTPUT_NODES, "true"); + + if (i < output_formats.size()) { + if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { + GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str()); + vector output_fp16_5hd_vec; + (void)ge::AttrUtils::GetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); + output_fp16_5hd_vec.push_back(std::to_string(user_out_nodes[i].second) + ":" + "NC1HWC0"); + (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_fp16_5hd", output_fp16_5hd_vec); + } + } + auto it = output_node_dt_map.find(user_out_nodes[i].first); + if (it != output_node_dt_map.end()) { + GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str()); + (void)ge::AttrUtils::SetListStr(op_desc, "_user_defined_output_data_type", it->second); + } + output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); + } + // default output node (leaf) + if (user_out_nodes.empty()) { + if (GetDefaultOutInfo(compute_graph, output_nodes_info) != SUCCESS) { + GELOGE(domi::FAILED, "Get default output info failed."); + return domi::FAILED; + } + } + GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); + compute_graph->SetGraphOutNodesInfo(output_nodes_info); + ge::GetParserContext().net_out_nodes = output_nodes_name; + GELOGI("Set graph %s output node success.", graph.GetName().c_str()); + return domi::SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseAclLogLevel(const std::string &log) { + if (log.empty()) { + return SUCCESS; + } + int ret = -1; + if (log == "default") { + ret = 0; + } else if (log == "null") { + ret = dlog_setlevel(-1, DLOG_NULL, 0); + } else if (log == "debug") { + ret = dlog_setlevel(-1, DLOG_DEBUG, 1); + } else if (log == "info") { + ret = dlog_setlevel(-1, DLOG_INFO, 1); + } else if (log == "warning") { + ret = dlog_setlevel(-1, DLOG_WARN, 1); + } else if (log == "error") { + ret = dlog_setlevel(-1, DLOG_ERROR, 1); + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"log", log}); + GELOGE(PARAM_INVALID, "Invalid value for log:%s, only support debug, info, warning, error, null", log.c_str()); + return PARAM_INVALID; + } + if (ret != 0) { + GELOGE(PARAM_INVALID, "Log setlevel fail !"); + } + return domi::SUCCESS; +} + +domi::Status AclGrphParseUtil::CheckOptions(const std::map &parser_params) { + for (auto &ele : parser_params) { + const char *key_ascend = ele.first.GetString(); + if (key_ascend == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + {"parser_params", "null AscendString"}); + GELOGE(PARAM_INVALID, "Input options key is null, Please check!"); + return PARAM_INVALID; + } + + string key_str = key_ascend; + auto it = ge::ir_option::ir_parser_suppported_options.find(key_str); + if (it == ge::ir_option::ir_parser_suppported_options.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"parser_params", key_str}); + GELOGE(PARAM_INVALID, "Input options include unsupported option(%s).Please check!", key_ascend); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::CheckAclInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input) { + if (!is_dynamic_input) { + for (auto node : graph->GetDirectNode()) { + if (node->GetType() == ge::parser::DATA) { + auto data_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(data_op_desc); + auto tensor_desc = data_op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(tensor_desc); + for (auto dim : tensor_desc->GetShape().GetDims()) { + if (dim < 0) { + GELOGE(PARAM_INVALID, + "Input op [%s] shape %ld is negative, maybe you should set input_shape to specify its shape", + node->GetName().c_str(), dim); + const string reason = "maybe you should set input_shape to specify its shape"; + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {node->GetName(), to_string(dim), reason}); + return PARAM_INVALID; + } + } + } + } + } + for (auto it : ge::GetParserContext().user_input_dims) { + std::string node_name = it.first; + ge::NodePtr node = graph->FindNode(node_name); + if (node == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name}); + GELOGE(PARAM_INVALID, "Input parameter[input_shape]'s opname[%s] is not exist in model", node_name.c_str()); + return PARAM_INVALID; + } + if (node->GetType() != ge::parser::DATA) { + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name}); + GELOGE(PARAM_INVALID, "Input parameter[input_shape]'s opname[%s] is not a input opname", node_name.c_str()); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) { + GE_CHECK_NOTNULL(graph); + unordered_map graphNodeTypes; + for (const NodePtr &node : graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(PARAM_INVALID, "Invalid parameter for opDesc."); + return PARAM_INVALID; + } + graphNodeTypes[op_desc->GetType()] = ""; + } + std::map &propertiesMap = ge::GetParserContext().op_conf_map; + if (propertiesMap.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, + {"op_name_map", op_conf, "the file content is empty"}); + GELOGE(PARAM_INVALID, "op_name_map file content is empty, please check file!"); + return PARAM_INVALID; + } + for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) { + GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(), + ErrorManager::GetInstance().ATCReportErrMessage( + "E10003", {"parameter", "value", "reason"}, + {"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"}); + GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); return PARAM_INVALID;); + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map &parser_params, + string &graph_name) { + GELOGI("Parse graph user options start."); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckOptions(parser_params) != SUCCESS, return PARAM_INVALID, + "Parse paragrams invalid."); + // support paragrams: log, input_format, is_dynamic_input, input_shape, out_nodes + // is_output_adjust_hw_layout, output, op_name_map, enable_scope_fusion_passes + string log_level; + GetAclParams(parser_params, ge::ir_option::LOG_LEVEL, log_level); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclLogLevel(log_level) != SUCCESS, return PARAM_INVALID, + "Parse log_level failed"); + + string input_format; + GetAclParams(parser_params, ge::ir_option::INPUT_FORMAT, input_format); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclFormat(input_format) != SUCCESS, return PARAM_INVALID, + "Parse input_format failed"); + + string dynamic_input_str; + GetAclParams(parser_params, ge::ir_option::IS_DYNAMIC_INPUT, dynamic_input_str); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !dynamic_input_str.empty() && !CheckInputTrueOrFalse(dynamic_input_str, "is_dynamic_input"), return PARAM_INVALID, + "Parse is_dynamic_input failed"); + bool is_dynamic_input = dynamic_input_str == "true" ? true : false; + + string input_shape; + GetAclParams(parser_params, ge::ir_option::INPUT_SHAPE, input_shape); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclShape(input_shape, is_dynamic_input) != SUCCESS, return PARAM_INVALID, + "Parse input_shape failed"); + + string out_nodes; + GetAclParams(parser_params, ge::ir_option::OUT_NODES, out_nodes); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputNodes(out_nodes) != SUCCESS, return PARAM_INVALID, + "Parse out_nodes failed"); + + string is_output_adjust_hw_layout; + GetAclParams(parser_params, ge::ir_option::IS_OUTPUT_ADJUST_HW_LAYOUT, is_output_adjust_hw_layout); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOutputFp16NodesFormat(is_output_adjust_hw_layout) != SUCCESS, + return PARAM_INVALID, "Parse is_output_adjust_hw_layout failed"); + + string op_conf_str; + GetAclParams(parser_params, ge::ir_option::OP_NAME_MAP, op_conf_str); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclOpConf(op_conf_str) != SUCCESS, return PARAM_INVALID, + "Parse op_name_map failed"); + + string tmp_name; + GetAclParams(parser_params, ge::ir_option::OUTPUT, tmp_name); + graph_name = tmp_name.empty() ? (kGraphDefaultName + "_" + ge::parser::CurrentTimeInStr()) : tmp_name; + + string enable_scope_fusion_passes; + GetAclParams(parser_params, ge::ir_option::ENABLE_SCOPE_FUSION_PASSES, enable_scope_fusion_passes); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclEnableScope(enable_scope_fusion_passes) != SUCCESS, return PARAM_INVALID, + "Parse enable_scope_fusion_passes failed"); + + return SUCCESS; +} + +domi::Status AclGrphParseUtil::ParseParamsAfterGraph(ge::Graph &graph, + const std::map &parser_params) { + // support paragrams: input_fp16_nodes, is_input_adjust_hw_layout, compress_weight_conf, + ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); + + string input_fp16_nodes; + GetAclParams(parser_params, ge::ir_option::INPUT_FP16_NODES, input_fp16_nodes); + + string is_input_adjust_hw_layout; + GetAclParams(parser_params, ge::ir_option::IS_INPUT_ADJUST_HW_LAYOUT, is_input_adjust_hw_layout); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ParseAclInputFp16Nodes(compute_graph, input_fp16_nodes, is_input_adjust_hw_layout) != SUCCESS, + return PARAM_INVALID, "Parse input_fp16_nodes failed"); + + bool is_dynamic_input = ge::GetParserContext().is_dynamic_input; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckAclInputShapeNode(compute_graph, is_dynamic_input) != SUCCESS, + return PARAM_INVALID, "Check nodes input_shape info failed"); + + string compress_weight_conf; + GetAclParams(parser_params, ge::ir_option::COMPRESS_WEIGHT_CONF, compress_weight_conf); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ParseAclWeightCompressConf(compute_graph, compress_weight_conf) != SUCCESS, + return PARAM_INVALID, "Parse compress_weight_conf failed"); + string op_conf_str; + GetAclParams(parser_params, ge::ir_option::OP_NAME_MAP, op_conf_str); + if (!op_conf_str.empty()) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(CheckAclOpNameMap(compute_graph, op_conf_str) != SUCCESS, return PARAM_INVALID, + "Check op_name_map info failed"); + } + + return SUCCESS; +} + namespace parser { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { if (path == nullptr) { diff --git a/parser/common/acl_graph_parser_util.h b/parser/common/acl_graph_parser_util.h index 8c182e1..a68cc03 100644 --- a/parser/common/acl_graph_parser_util.h +++ b/parser/common/acl_graph_parser_util.h @@ -17,14 +17,18 @@ #ifndef ACL_GRAPH_PARSE_UTIL_ #define ACL_GRAPH_PARSE_UTIL_ -#include -#include #include + +#include #include +#include +#include +#include #include "framework/omg/parser/parser_types.h" -#include "register/register_error_codes.h" +#include "graph/ascend_string.h" #include "graph/utils/graph_utils.h" +#include "register/register_error_codes.h" namespace ge { @@ -37,13 +41,39 @@ class AclGrphParseUtil { domi::Status LoadOpsProtoLib(); void SaveCustomCaffeProtoPath(); domi::Status AclParserInitialize(const std::map &options); - domi::Status SetDefaultOutputNode(ge::Graph &graph); + domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map &parser_params); + domi::Status ParseParamsBeforeGraph(const std::map &parser_params, + std::string &graph_name); + domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map &parser_params); + domi::Status ParseOutputInfo(ge::Graph &graph, const std::map &parser_params); private: bool parser_initialized = false; + domi::Status CheckOptions(const std::map &parser_params); domi::Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info); void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, std::vector &output_nodes_name); + domi::Status ParseAclLogLevel(const std::string &log); + bool CheckAclInputFormat(string &input_format); + domi::Status ParseAclFormat(std::string &input_format); + bool ParseInputShape(const std::string &input_shape, std::unordered_map> &shape_map, + vector>> &user_shape_map, bool is_dynamic_input); + domi::Status ParseAclShape(const std::string &input_shape, bool is_dynamic_input); + domi::Status ParseAclOutputNodes(const std::string &out_nodes); + domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16); + domi::Status ParseAclOpConf(const std::string &op_conf); + domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes); + static void AddAttrsForInputNodes(const vector &adjust_fp16_format_vec, const string &fp16_nodes_name, + uint32_t index, OpDescPtr &op_desc); + domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, + const string &is_input_adjust_hw_layout); + domi::Status ParseAclWeightCompressConf(const ComputeGraphPtr &graph, const string &compress_weight_conf); + domi::Status ParseAclOutputType(const std::string &output_type, + std::map> &output_node_dt_map); + domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, + std::vector> &output_nodes_info); + domi::Status CheckAclInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input); + domi::Status CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf); }; namespace parser { diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index cdf727e..1800d00 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -113,13 +113,59 @@ graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { return ge::FAILED; } - if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { + std::map parser_params; + if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); return ge::FAILED; } GELOGI("Parser graph %s success.", graph.GetName().c_str()); return ge::SUCCESS; } + +graphStatus aclgrphParseTensorFlow(const char *model_file, const std::map &parser_params, + ge::Graph &graph) { + GE_CHECK_NOTNULL(model_file); + GetParserContext().type = domi::TENSORFLOW; + std::map options; + options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::TENSORFLOW))); + + // load custom plugin so and proto + AclGrphParseUtil acl_graph_parse_util; + (void)acl_graph_parse_util.AclParserInitialize(options); + + string output_name; + if (acl_graph_parse_util.ParseParamsBeforeGraph(parser_params, output_name) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params before graph failed."); + return ge::FAILED; + } + // Create an empty computegraph + string graph_name = output_name.empty() ? "tmpGraph" : output_name; + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared(graph_name); + GE_CHECK_NOTNULL(compute_graph); + + graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); + GE_CHECK_NOTNULL(model_parser); + + // parse tensorflow model_file to GE graph + ge::graphStatus ret = model_parser->Parse(model_file, graph); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); + return ge::FAILED; + } + + if (acl_graph_parse_util.ParseParamsAfterGraph(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Parser params after graph failed."); + return ge::FAILED; + } + + if (acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Set graph %s default output node failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("AclgrphParse graph %s success.", graph.GetName().c_str()); + return ge::SUCCESS; +} } // namespace ge namespace ge {