diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index abc5d4d..555e451 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -73,14 +73,17 @@ const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\"" const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; static std::set kCaffeSupportInputFormatSet = {"NCHW", "ND"}; static std::set kTfSupportInputFormatSet = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; -static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; -static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; +const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; +const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; /// The maximum length of the file. /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 const int kMaxFileSizeLimit = INT_MAX; const int kMaxBuffSize = 256; const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M +const int kOutputTypeNode = 0; +const int kOutputTypeIndex = 1; +const int kOutputTypeDataType = 2; vector SplitInputShape(const std::string &input_shape) { vector shape_pair_vec; @@ -137,7 +140,7 @@ static void GetOpsProtoPath(string &opsproto_path) { } static void GetAclParams(const std::map &parser_params, - const string &key, string &value) { + const string &key, string &value) { auto iter = parser_params.find(key); if (iter != parser_params.end()) { value = iter->second; @@ -375,7 +378,7 @@ bool AclGrphParseUtil::CheckAclInputFormat(string &input_format) { GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", input_format.c_str(), kCaffeFormatSupport); return false; - } else if (ge::GetParserContext().type ==domi::TENSORFLOW) { // tf + } else if (ge::GetParserContext().type == domi::TENSORFLOW) { // tf if (kTfSupportInputFormatSet.find(input_format) != kTfSupportInputFormatSet.end()) { return true; } @@ -494,7 +497,7 @@ bool AclGrphParseUtil::ParseInputShape(const string &input_shape, return true; } -//Parse user input shape info +// Parse user input shape info domi::Status AclGrphParseUtil::ParseAclShape(const string &input_shape, bool is_dynamic_input) { ge::GetParserContext().input_dims.clear(); ge::GetParserContext().user_input_dims.clear(); @@ -541,7 +544,8 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, {"--out_nodes", out_nodes, "is not all index or top_name"}); - GELOGE(PARAM_INVALID, "This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str()); + GELOGE(PARAM_INVALID, + "This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str()); return PARAM_INVALID; } // stoi: The method may throw an exception: invalid_argument/out_of_range @@ -647,8 +651,8 @@ domi::Status AclGrphParseUtil::ParseAclEnableScope(const string &enable_scope_fu } void AclGrphParseUtil::AddAttrsForInputNodes(const vector &adjust_fp16_format_vec, - const string &fp16_nodes_name, uint32_t index, - OpDescPtr &op_desc) { + const string &fp16_nodes_name, uint32_t index, + OpDescPtr &op_desc) { if (AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) { GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str()); @@ -727,6 +731,10 @@ domi::Status AclGrphParseUtil::ParseAclWeightCompressConf(const ComputeGraphPtr std::string compress_nodes; ifs >> compress_nodes; ifs.close(); + if (compress_nodes.empty()) { + GELOGW("Compress weight of nodes info is empty"); + return SUCCESS; + } GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); vector compress_node_vec = StringUtils::Split(compress_nodes, ';'); @@ -763,14 +771,14 @@ domi::Status AclGrphParseUtil::ParseAclOutputType(const std::string &output_type return domi::FAILED; } ge::DataType tmp_dt; - std::string node_name = StringUtils::Trim(node_index_type_v[0]); - std::string index_str = StringUtils::Trim(node_index_type_v[1]); + std::string node_name = StringUtils::Trim(node_index_type_v[kOutputTypeNode]); + std::string index_str = StringUtils::Trim(node_index_type_v[kOutputTypeIndex]); int32_t index; if (StringToInt(index_str, index) != SUCCESS) { GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); return domi::FAILED; } - std::string dt_value = StringUtils::Trim(node_index_type_v[2]); + std::string dt_value = StringUtils::Trim(node_index_type_v[kOutputTypeDataType]); auto it = kOutputTypeSupportDatatype.find(dt_value); if (it == kOutputTypeSupportDatatype.end()) { ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, @@ -811,7 +819,8 @@ void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vectorGetName(); int32_t index = output_nodes_info[i].second; if (i < ge::GetParserContext().out_top_names.size()) { - output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + ge::GetParserContext().out_top_names[i]); + output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + + ge::GetParserContext().out_top_names[i]); } else { GELOGW("Get top name of node [%s] fail.", node_name.c_str()); output_nodes_name.push_back(node_name + ":" + std::to_string(index)); @@ -982,6 +991,86 @@ domi::Status AclGrphParseUtil::CheckOptions(const std::mapGetDirectNode()) + { + if (node->GetType() == ge::parser::DATA) + { + auto data_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(data_op_desc); + auto tensor_desc = data_op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(tensor_desc); + for (auto dim : tensor_desc->GetShape().GetDims()) + { + if (dim < 0) + { + GELOGE(PARAM_INVALID, + "Input op [%s] shape %ld is negative, maybe you should set input_shape to specify its shape", + node->GetName().c_str(), dim); + const string reason = "maybe you should set input_shape to specify its shape"; + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, {node->GetName(), to_string(dim), reason}); + return PARAM_INVALID; + } + } + } + } + } + for (auto it : ge::GetParserContext().user_input_dims) + { + std::string node_name = it.first; + ge::NodePtr node = graph->FindNode(node_name); + if (node == nullptr) + { + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name}); + GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not exist in model", node_name.c_str()); + return PARAM_INVALID; + } + if (node->GetType() != ge::parser::DATA) + { + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name}); + GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not a input opname", node_name.c_str()); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +domi::Status AclGrphParseUtil::CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) { + GE_CHECK_NOTNULL(graph); + unordered_map graphNodeTypes; + for (const NodePtr &node : graph->GetAllNodes()) + { + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) + { + GELOGE(PARAM_INVALID, "Invalid parameter for opDesc."); + return PARAM_INVALID; + } + graphNodeTypes[op_desc->GetType()] = ""; + } + std::map &propertiesMap = ge::GetParserContext().op_conf_map; + if (propertiesMap.empty()) + { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10003", {"parameter", "value", "reason"}, {"op_name_map", op_conf, "the file content is empty"}); + GELOGE(PARAM_INVALID, "op_name_map file content is empty, please check file!"); + return PARAM_INVALID; + } + for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) + { + GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(), + ErrorManager::GetInstance().ATCReportErrMessage( + "E10003", + {"parameter", "value", "reason"}, + {"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"}); + GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); + return PARAM_INVALID;); + } + return SUCCESS; +} domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map &parser_params, string &graph_name) { @@ -995,7 +1084,6 @@ domi::Status AclGrphParseUtil::ParseParamsBeforeGraph(const std::map> &output_node_dt_map); domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, std::vector> &output_nodes_info); + domi::Status CheckAclInputShapeNode(const ComputeGraphPtr &graph, const bool is_dynamic_input); + domi::Status CheckAclOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf); }; -namespace parser { -/// -/// @ingroup: domi_common -/// @brief: get length of file -/// @param [in] input_file: path of file -/// @return long: File length. If the file length fails to be obtained, the value -1 is returned. -/// -extern long GetFileLength(const std::string &input_file); + namespace parser + { + /// + /// @ingroup: domi_common + /// @brief: get length of file + /// @param [in] input_file: path of file + /// @return long: File length. If the file length fails to be obtained, the value -1 is returned. + /// + extern long GetFileLength(const std::string &input_file); -/// -/// @ingroup domi_common -/// @brief Absolute path for obtaining files. -/// @param [in] path of input file -/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned -/// -std::string RealPath(const char *path); + /// + /// @ingroup domi_common + /// @brief Absolute path for obtaining files. + /// @param [in] path of input file + /// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned + /// + std::string RealPath(const char *path); -/// -/// @ingroup domi_common -/// @brief Obtains the absolute time (timestamp) of the current system. -/// @return Timestamp, in microseconds (US) -/// -/// -uint64_t GetCurrentTimestamp(); + /// + /// @ingroup domi_common + /// @brief Obtains the absolute time (timestamp) of the current system. + /// @return Timestamp, in microseconds (US) + /// + /// + uint64_t GetCurrentTimestamp(); -/// -/// @ingroup domi_common -/// @brief Reads all data from a binary file. -/// @param [in] file_name path of file -/// @param [out] buffer Output memory address, which needs to be released by the caller. -/// @param [out] length Output memory size -/// @return false fail -/// @return true success -/// -bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); + /// + /// @ingroup domi_common + /// @brief Reads all data from a binary file. + /// @param [in] file_name path of file + /// @param [out] buffer Output memory address, which needs to be released by the caller. + /// @param [out] length Output memory size + /// @return false fail + /// @return true success + /// + bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); -/// -/// @ingroup domi_common -/// @brief proto file in bianary format -/// @param [in] file path of proto file -/// @param [out] proto memory for storing the proto file -/// @return true success -/// @return false fail -/// -bool ReadProtoFromBinaryFile(const char *file, Message *proto); + /// + /// @ingroup domi_common + /// @brief proto file in bianary format + /// @param [in] file path of proto file + /// @param [out] proto memory for storing the proto file + /// @return true success + /// @return false fail + /// + bool ReadProtoFromBinaryFile(const char *file, Message *proto); -/// -/// @ingroup domi_common -/// @brief Reads the proto structure from an array. -/// @param [in] data proto data to be read -/// @param [in] size proto data size -/// @param [out] proto Memory for storing the proto file -/// @return true success -/// @return false fail -/// -bool ReadProtoFromArray(const void *data, int size, Message *proto); + /// + /// @ingroup domi_common + /// @brief Reads the proto structure from an array. + /// @param [in] data proto data to be read + /// @param [in] size proto data size + /// @param [out] proto Memory for storing the proto file + /// @return true success + /// @return false fail + /// + bool ReadProtoFromArray(const void *data, int size, Message *proto); -/// -/// @ingroup domi_proto -/// @brief Reads the proto file in the text format. -/// @param [in] file path of proto file -/// @param [out] message Memory for storing the proto file -/// @return true success -/// @return false fail -/// -bool ReadProtoFromText(const char *file, google::protobuf::Message *message); + /// + /// @ingroup domi_proto + /// @brief Reads the proto file in the text format. + /// @param [in] file path of proto file + /// @param [out] message Memory for storing the proto file + /// @return true success + /// @return false fail + /// + bool ReadProtoFromText(const char *file, google::protobuf::Message *message); -bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); + bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); -/// -/// @brief get the Original Type of FrameworkOp -/// @param [in] node -/// @param [out] type -/// @return Status -/// -domi::Status GetOriginalType(const ge::NodePtr &node, string &type); + /// + /// @brief get the Original Type of FrameworkOp + /// @param [in] node + /// @param [out] type + /// @return Status + /// + domi::Status GetOriginalType(const ge::NodePtr &node, string &type); -/// -/// @ingroup domi_common -/// @brief Check whether the file path meets the whitelist verification requirements. -/// @param [in] filePath file path -/// @param [out] result -/// -bool ValidateStr(const std::string &filePath, const std::string &mode); + /// + /// @ingroup domi_common + /// @brief Check whether the file path meets the whitelist verification requirements. + /// @param [in] filePath file path + /// @param [out] result + /// + bool ValidateStr(const std::string &filePath, const std::string &mode); -/// -/// @ingroup domi_common -/// @brief Obtains the current time string. -/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555 -/// -std::string CurrentTimeInStr(); + /// + /// @ingroup domi_common + /// @brief Obtains the current time string. + /// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555 + /// + std::string CurrentTimeInStr(); -template -static inline std::shared_ptr MakeShared(Args &&... args) { - typedef typename std::remove_const::type T_nc; - std::shared_ptr ret(new (std::nothrow) T_nc(std::forward(args)...)); - return ret; + template + static inline std::shared_ptr MakeShared(Args &&... args) + { + typedef typename std::remove_const::type T_nc; + std::shared_ptr ret(new (std::nothrow) T_nc(std::forward(args)...)); + return ret; } /// @ingroup math_util