diff --git a/CMakeLists.txt b/CMakeLists.txt index 17cedff..2c0aa4b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,7 @@ if (ENABLE_OPEN_SRC) find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) elseif(ENABLE_GE_COV OR ENABLE_GE_UT) message(STATUS "Runing on llt mode, no need to depend other component") - elseif(ENABLE_PARSER_UT OR ENABLE_PARSER_COV) + elseif(ENABLE_PARSER_UT OR ENABLE_PARSER_COV OR ENABLE_PARSER_ST) include(cmake/external_libs/gtest.cmake) add_subdirectory(tests) else() diff --git a/build.sh b/build.sh index 073d47d..3e4df13 100644 --- a/build.sh +++ b/build.sh @@ -151,6 +151,8 @@ build_parser() if [ "X$ENABLE_PARSER_UT" = "Xon" ]; then make ut_parser -j8 + elif [ "X$ENABLE_PARSER_ST" = "Xon" ]; then + make st_parser -j8 else make ${VERBOSE} -j${THREAD_NUM} && make install fi @@ -194,6 +196,17 @@ if [[ "X$ENABLE_PARSER_UT" = "Xon" || "X$ENABLE_PARSER_COV" = "Xon" ]]; then genhtml coverage.info fi +if [[ "X$ENABLE_PARSER_ST" = "Xon" ]]; then + cp ${BUILD_PATH}/tests/st/st_parser ${OUTPUT_PATH} + + RUN_TEST_CASE=${OUTPUT_PATH}/st_parser && ${RUN_TEST_CASE} + if [[ "$?" -ne 0 ]]; then + echo "!!! ST FAILED, PLEASE CHECK YOUR CHANGES !!!" + echo -e "\033[31m${RUN_TEST_CASE}\033[0m" + exit 1; + fi +fi + # generate output package in tar form, including ut/st libraries/executables generate_package() { @@ -236,7 +249,7 @@ generate_package() tar -cf parser_lib.tar fwkacllib acllib atc } -if [[ "X$ENABLE_PARSER_UT" = "Xoff" ]]; then +if [[ "X$ENABLE_PARSER_UT" = "Xoff" && "X$ENABLE_PARSER_ST" = "Xoff" ]]; then generate_package fi echo "---------------- Parser package archive generated ----------------" diff --git a/metadef b/metadef index 51418f6..708aff8 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 51418f61f26599c85bee2b57328afbbf1c9927c7 +Subproject commit 708aff80f34bd89c4d69c24a705aa50025e51c4f diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index fe5086f..c2842be 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -601,17 +601,17 @@ void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_ind } Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) { - if (ge::GetParserContext().user_out_nodes_top_vec.empty()) { + if (ge::GetParserContext().user_out_tensors.empty()) { return SUCCESS; } ge::GetParserContext().out_nodes_map.clear(); ge::GetParserContext().user_out_nodes.clear(); int32_t layer_count = proto_message.layer_size(); - const std::vector &user_out_nodes_top_vec = - ge::GetParserContext().user_out_nodes_top_vec; + const std::vector &user_out_tensors = + ge::GetParserContext().user_out_tensors; - for (const auto &top_name : user_out_nodes_top_vec) { + for (const auto &top_name : user_out_tensors) { bool find_node_falg = false; string layer_name; int32_t top_index = 0; @@ -1082,7 +1082,7 @@ Status CaffeModelParser::AddUserOutNodesTop() { string top_name = layer_iter->second[out_pair.second]; auto top_node_iter = node_map.find(out_pair.first); if (top_node_iter != node_map.end()) { - ge::GetParserContext().out_top_names.push_back(top_name); + ge::GetParserContext().out_tensor_names.push_back(top_name); GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); } ++index; @@ -1129,7 +1129,7 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes auto top_node_iter = node_map.find(layer.name()); GELOGI("output in top_blob: %s", layer.name().c_str()); if (top_node_iter != node_map.end()) { - ge::GetParserContext().out_top_names.push_back(top_origin); + ge::GetParserContext().out_tensor_names.push_back(top_origin); ge::GetParserContext().default_out_nodes.push_back(std::make_pair(layer.name(), (int32_t)i)); GELOGI("The top of out node [%s] is [%s]", layer.name().c_str(), top_origin.c_str()); } @@ -1389,13 +1389,13 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la } string top_name = layer.top(0); - auto data_top_names = ge::GetParserContext().data_top_names; - if (find(data_top_names.begin(), data_top_names.end(), top_name) != data_top_names.end()) { + auto data_tensor_names = ge::GetParserContext().data_tensor_names; + if (find(data_tensor_names.begin(), data_tensor_names.end(), top_name) != data_tensor_names.end()) { ErrorManager::GetInstance().ATCReportErrMessage("E11036", {"topname"}, {top_name}); GELOGE(FAILED, "[Check][Node]Different data node can not have same top name: %s.", top_name.c_str()); return FAILED; } - ge::GetParserContext().data_top_names.push_back(top_name); + ge::GetParserContext().data_tensor_names.push_back(top_name); } return SUCCESS; @@ -1464,18 +1464,18 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap int32_t layer_count = proto_message.layer_size(); - if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { + if (!ge::GetParserContext().user_out_tensors.empty()) { GELOGW("The out_put info has top_name items."); GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputNodeTopInfo(proto_message), "[Parse][OutputNodeTopInfo] failed."); - ge::GetParserContext().user_out_nodes_top_vec.clear(); + ge::GetParserContext().user_out_tensors.clear(); } std::map inplace_blob_name_remapping; // Map of operator name and occurrence times std::map layer_name_map; - GetParserContext().data_top_names.clear(); + GetParserContext().data_tensor_names.clear(); // std::map> layer_params_map; // same param name set diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index 731159d..bf63f35 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -52,6 +52,13 @@ const int kMaxFileSizeLimit = INT_MAX; const int kMaxBuffSize = 256; const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M +const uint32_t kSetOutputWithNodeAndIndex = 0x1; +const uint32_t kSetOutputWithTensorName = 0x2; +const uint32_t kSetOutputModeMixed = 0x3; +const std::unordered_set kSupportTensorAsOutput = { + domi::CAFFE, + domi::ONNX +}; static string GetSoPath() { Dl_info dl_info; @@ -263,14 +270,19 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { if (!out_nodes.empty()) { ge::GetParserContext().out_nodes_map.clear(); ge::GetParserContext().user_out_nodes.clear(); - ge::GetParserContext().user_out_nodes_top_vec.clear(); + ge::GetParserContext().user_out_tensors.clear(); + uint32_t set_output_mode = 0; vector nodes_v = StringUtils::Split(out_nodes, ';'); for (const string &node : nodes_v) { vector key_value_v = StringUtils::Split(node, ':'); if (key_value_v.size() != 2) { // The size must be 2. - if (key_value_v.size() == 1 && ge::GetParserContext().type == domi::CAFFE) { - ge::GetParserContext().user_out_nodes_top_vec.push_back(node); + if (key_value_v.size() == 1 && kSupportTensorAsOutput.count(ge::GetParserContext().type) > 0) { + set_output_mode |= kSetOutputWithTensorName; + if (set_output_mode == kSetOutputModeMixed) { + break; + } + ge::GetParserContext().user_out_tensors.push_back(node); continue; } ErrorManager::GetInstance().ATCReportErrMessage( @@ -281,12 +293,9 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { node.c_str()); return PARAM_INVALID; } - if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"out_nodes", out_nodes, "is not all index or top_name"}); - GELOGE(PARAM_INVALID, "[Check][Param] This out_nodes str must be all index or top_name, " - "while the actual input is %s", out_nodes.c_str()); - return PARAM_INVALID; + set_output_mode |= kSetOutputWithNodeAndIndex; + if (set_output_mode == kSetOutputModeMixed) { + break; } // stoi: The method may throw an exception: invalid_argument/out_of_range if (!CheckDigitStr(key_value_v[1])) { @@ -309,6 +318,13 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { } ge::GetParserContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); } + if (set_output_mode == kSetOutputModeMixed) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--out_nodes", out_nodes, "is not all index or top_name"}); + GELOGE(PARAM_INVALID, "[Parse][Param]This out_nodes str must be all index or tensor_name, " + "while the actual input is %s", out_nodes.c_str()); + return PARAM_INVALID; + } } } catch (std::invalid_argument &) { GELOGE(PARAM_INVALID, "[Check][Param] Invalid of out_nodes: %s ", out_nodes.c_str()); @@ -410,10 +426,11 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra return SUCCESS; } -void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, - std::vector &output_nodes_name) { +void AclGrphParseUtil::CreateOutputNodesInfo(std::vector> &output_nodes_info, + std::vector &output_nodes_name) { output_nodes_name.clear(); - if (ge::GetParserContext().out_top_names.empty()) { + auto &out_tensor_names = ge::GetParserContext().out_tensor_names; + if (out_tensor_names.empty()) { // tf process, no top name. for (const auto output_node_info : output_nodes_info) { std::string node_name = output_node_info.first->GetName(); @@ -422,13 +439,18 @@ void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vectorGetName(); + auto node = output_nodes_info[i].first; int32_t index = output_nodes_info[i].second; - if (i < ge::GetParserContext().out_top_names.size()) { - output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + - ge::GetParserContext().out_top_names[i]); + std::string node_name = node->GetName(); + if (i < out_tensor_names.size()) { + auto output_desc = node->GetOpDesc()->MutableOutputDesc(static_cast(index)); + (void)AttrUtils::SetStr(output_desc, ATTR_NAME_ORIGIN_OUTPUT_TENSOR_NAME, out_tensor_names[i]); + std::string output_name = node->GetName() + ":" + std::to_string(index) + ":" + out_tensor_names[i]; + output_nodes_name.push_back(output_name); + GELOGD("Output[%zu] name[%s]", i, output_name.c_str()); } else { GELOGW("Get top name of node [%s] fail.", node_name.c_str()); output_nodes_name.push_back(node_name + ":" + std::to_string(index)); @@ -469,7 +491,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, std::vector> &output_nodes_info) { std::vector> default_out_nodes = ge::GetParserContext().default_out_nodes; - if (ge::GetParserContext().type == domi::CAFFE && !default_out_nodes.empty()) { + if (!default_out_nodes.empty()) { for (uint32_t i = 0; i < default_out_nodes.size(); ++i) { ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); if (out_node == nullptr) { @@ -543,7 +565,7 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, return domi::FAILED; } } - GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); + CreateOutputNodesInfo(output_nodes_info, output_nodes_name); compute_graph->SetGraphOutNodesInfo(output_nodes_info); ge::GetParserContext().net_out_nodes = output_nodes_name; GELOGI("Set graph %s output node success.", graph.GetName().c_str()); diff --git a/parser/common/acl_graph_parser_util.h b/parser/common/acl_graph_parser_util.h index 4d11a35..e7190e9 100644 --- a/parser/common/acl_graph_parser_util.h +++ b/parser/common/acl_graph_parser_util.h @@ -50,8 +50,8 @@ class AclGrphParseUtil { bool parser_initialized = false; domi::Status CheckOptions(const std::map &parser_params); domi::Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info); - void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, - std::vector &output_nodes_name); + void CreateOutputNodesInfo(std::vector> &output_nodes_info, + std::vector &output_nodes_name); void SetDefaultFormat(); domi::Status ParseAclOutputNodes(const std::string &out_nodes); domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16); diff --git a/parser/common/parser_utils.cc b/parser/common/parser_utils.cc index 58e178e..74febc9 100644 --- a/parser/common/parser_utils.cc +++ b/parser/common/parser_utils.cc @@ -71,7 +71,7 @@ Status HandleNewOp(const NodePtr &node, } } -Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { +Status ParserUtils::ExpandOneToManyGraph(Graph &graph, OutputMapping &output_mapping) { GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); for (const auto &gn : graph.GetDirectNode()) { NodePtr n = NodeAdapter::GNode2Node(gn); @@ -95,7 +95,7 @@ Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { GELOGE(FAILED, "[Invoke][ParseOpToGraphFunc]Get one to many graph failed for op:%s.", op.GetName().c_str()); return FAILED; } - ret = ExpandNodeToSubgraph(subgraph, n, graph); + ret = ExpandNodeToSubgraph(subgraph, n, graph, output_mapping); if (ret != SUCCESS) { GELOGE(FAILED, "[Invoke][ExpandNodeToSubgraph]Expand one to many graph failed for op:%s.", op.GetName().c_str()); return FAILED; @@ -105,7 +105,8 @@ Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { return SUCCESS; } -Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph) { +Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph, + OutputMapping &output_mapping) { ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph); GE_CHECK_NOTNULL(sub_compute_graph); ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); @@ -135,7 +136,7 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n // handle output context. std::vector> out_node_index = sub_compute_graph->GetGraphOutNodesInfo(); - ret = HandleOutputContext(node, out_node_index); + ret = HandleOutputContext(node, out_node_index, output_mapping); if (ret != SUCCESS) { GELOGE(FAILED, "[Run][HandleOutputContext] failed, node:%s.", node->GetName().c_str()); return FAILED; @@ -235,7 +236,8 @@ Status ParserUtils::HandleInputContext(const NodePtr &node, } Status ParserUtils::HandleOutputContext(const NodePtr &node, - const std::vector> &out_node_index) { + const std::vector> &out_node_index, + OutputMapping &output_mapping) { GE_CHECK_NOTNULL(node); GELOGD("The size of out node is %zu", out_node_index.size()); for (size_t index = 0; index < out_node_index.size(); index++) { @@ -247,6 +249,8 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node, NodePtr out_node = out_node_index[index].first; int32_t out_index = out_node_index[index].second; GELOGD("Begin to handle output node:%s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index); + std::string key = GenOutputKey({node->GetName(), index}); + output_mapping[key] = std::make_pair(out_node->GetName(), out_index); auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor. GE_CHECK_NOTNULL(src_out_anchor); for (const auto &dest_in_anchor : node_out_anchor->GetPeerInDataAnchors()) { @@ -273,4 +277,26 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node, } return SUCCESS; } + +string ParserUtils::GenOutputKey(const OutputNodeInfo &node_info) { + return node_info.first + ":" + std::to_string(node_info.second); +} + +void ParserUtils::UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info) { + std::string key = ParserUtils::GenOutputKey(output_node_info); + auto iter = final_output_nodes.find(key); + if (iter != final_output_nodes.end()) { + output_node_info = iter->second; + GELOGD("Update output node info, origin[%s], now[%s].", + key.c_str(), ParserUtils::GenOutputKey(output_node_info).c_str()); + } +} + +void ParserUtils::UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes) { + for (auto &tensor_to_node : tensor_to_nodes) { + std::string tensor_name = tensor_to_node.first; + auto &output_node_info = tensor_to_node.second; + UpdateOutputNodeInfo(final_output_nodes, output_node_info); + } +} } // namespace ge \ No newline at end of file diff --git a/parser/common/parser_utils.h b/parser/common/parser_utils.h index 6c19fb7..7563ee7 100644 --- a/parser/common/parser_utils.h +++ b/parser/common/parser_utils.h @@ -17,6 +17,7 @@ #ifndef PARSER_COMMON_PARSER_UTILS_H_ #define PARSER_COMMON_PARSER_UTILS_H_ +#include #include "graph/graph.h" #include "graph/node.h" #include "external/ge/ge_api_error_codes.h" @@ -24,15 +25,22 @@ namespace ge { class ParserUtils { public: - static Status ExpandOneToManyGraph(Graph &graph); + using OutputNodeInfo = std::pair; + using OutputMapping = std::unordered_map; + static Status ExpandOneToManyGraph(Graph &graph, OutputMapping &output_mapping); + static string GenOutputKey(const OutputNodeInfo &node_info); + static void UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info); + static void UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes); private: - static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph); + static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph, + OutputMapping &output_mapping); static Status HandleInputContext(const NodePtr &node, const std::vector &input_nodes, const ComputeGraphPtr &compute_graph); static Status HandleOutputContext(const NodePtr &node, - const std::vector> &out_node_index); + const std::vector> &out_node_index, + OutputMapping &output_mapping); }; } // namespace ge #endif // PARSER_COMMON_PARSER_UTILS_H_ diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 3c79a71..e27c7f6 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -360,7 +360,7 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, return SUCCESS; } -Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) { +void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) { int index = 0; for (int i = 0; i < onnx_graph.node_size(); i++) { ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); @@ -369,8 +369,6 @@ Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) { node->set_name(node_name); } } - - return SUCCESS; } Status OnnxModelParser::ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type) { @@ -676,7 +674,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve return SUCCESS; } -Status OnnxModelParser::GetGraphOutputs(std::vector>> &output_ops) { +Status OnnxModelParser::GetGraphOutputs(std::vector>> &output_ops, + ParserUtils::OutputMapping &out_tensor_to_nodes) { for (auto output_name : output_node_names_) { auto itr = outputs_map_.find(output_name); if (itr == outputs_map_.end()) { @@ -696,6 +695,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vectorsecond, vector{static_cast(index)}); + out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); GELOGI("out node index %d, node:%s", index, node_name.c_str()); } } @@ -870,7 +870,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), "Run ProtoType Pass Failed"); - // 2. Get all inializer. + // 1. Get all inializer. std::map initializer_name_tensor; for (int i = 0; i < onnx_graph.initializer_size(); i++) { ge::onnx::TensorProto initializer_tensor = onnx_graph.initializer(i); @@ -880,7 +880,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP } } - // 3. Parse Input from graph. + // 2. Parse Input from graph. GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); Status ret = ParseInput(initializer_name_tensor, is_subgraph, onnx_graph); @@ -890,13 +890,14 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP } GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); - // 4. Parse Constant from graph. + // 3. Parse Constant from graph. ret = ParseInitializer(onnx_graph, initializer_name_tensor); if (ret != SUCCESS) { GELOGE(ret, "[Parse][Initializer] for onnx failed."); return ret; } + // 4. Get all output name form origin graph ret = ParseOutput(onnx_graph); if (ret != SUCCESS) { GELOGE(ret, "[Parse][Output] Parse output for onnx failed."); @@ -904,11 +905,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP } // 5. Update node name for node do not has name. - ret = UpdateAllNodeName(onnx_graph); - if (ret != SUCCESS) { - GELOGE(ret, "[Update][Name] of all node for onnx failed."); - return ret; - } + UpdateAllNodeName(onnx_graph); // 6 Precheck. ret = Prechecker(onnx_graph); @@ -950,19 +947,32 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP } graph.SetInputs(input_ops); + // 10. Get output info and set outpus for subgraph + std::vector>> output_ops; + ParserUtils::OutputMapping out_tensor_to_nodes; + ret = GetGraphOutputs(output_ops, out_tensor_to_nodes); + if (ret != SUCCESS) { + GELOGE(ret, "[Get][Outputs] failed."); + return ret; + } // root graph needn't set outputs. if(is_subgraph) { - std::vector>> output_ops; - ret = GetGraphOutputs(output_ops); + graph.SetOutputs(output_ops); + } + + // 11. Expand node to graph if need + ParserUtils::OutputMapping final_output_nodes; + GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph, final_output_nodes)); + + // 12. Set outputs info in ParserContext for root graph + if (!is_subgraph) { + ret = SetOutputsInfo(final_output_nodes, out_tensor_to_nodes); if (ret != SUCCESS) { - GELOGE(ret, "[Get][Outputs] failed."); + GELOGE(ret, "[Set][OutputsInfo] Graph:%s.", graph.GetName().c_str()); return ret; } - graph.SetOutputs(output_ops); } - GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); - GELOGI("Onnx model parser success."); return SUCCESS; } @@ -1048,6 +1058,50 @@ void OnnxModelParser::UpdateDataFormat(ge::Graph &graph) { return; } +Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes, + const ParserUtils::OutputMapping &tensor_to_nodes) { + auto &user_specified_nodes = ge::GetParserContext().user_out_nodes; + if (!user_specified_nodes.empty()) { + GELOGI("User specified the output nodes with node_name and index."); + for (auto &output_node_info : user_specified_nodes) { + ParserUtils::UpdateOutputNodeInfo(final_output_nodes, output_node_info); + } + return SUCCESS; + } + + auto final_tensor_to_nodes = tensor_to_nodes; + ParserUtils::UpdateOutputCtx(final_output_nodes, final_tensor_to_nodes); + auto &user_specified_tensors = ge::GetParserContext().user_out_tensors; + auto &output_tensor_names = ge::GetParserContext().out_tensor_names; + output_tensor_names.clear(); + if (!user_specified_tensors.empty()) { + for (auto &tensor_name : user_specified_tensors) { + auto iter = final_tensor_to_nodes.find(tensor_name); + if (iter != final_tensor_to_nodes.end()) { + user_specified_nodes.emplace_back(iter->second); + output_tensor_names.emplace_back(tensor_name); + GELOGI("[UserSpecified]Add network output node[%s], index[%d], tensor name[%s].", + iter->second.first.c_str(), iter->second.second, tensor_name.c_str()); + } else { + REPORT_INNER_ERROR("E19999", "User specified tensor[%s] is not output of graph.", tensor_name.c_str()); + GELOGE(FAILED, "[Set][OutputsInfo]User specified tensor[%s] is not output of graph.", tensor_name.c_str()); + return FAILED; + } + } + return SUCCESS; + } + + // for default output + auto &default_out_nodes = ge::GetParserContext().default_out_nodes; + for (auto &tensor_name : output_node_names_) { + auto &output_node_info = final_tensor_to_nodes[tensor_name]; + default_out_nodes.emplace_back(output_node_info); + output_tensor_names.emplace_back(tensor_name); + GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", + output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str()); + } + return SUCCESS; +} } // namespace domi namespace domi { diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index b90c1a3..60997bd 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -38,6 +38,7 @@ #include "omg/parser/model_parser.h" #include "omg/parser/op_parser.h" #include "omg/parser/weights_parser.h" +#include "common/parser_utils.h" #include "proto/onnx/ge_onnx.pb.h" namespace ge { @@ -80,7 +81,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, std::map &initializer_name_tensor); - Status UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph); + void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph); Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); @@ -94,7 +95,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector &input_ops); - Status GetGraphOutputs(std::vector>> &outputs); + Status GetGraphOutputs(std::vector>> &outputs, + ParserUtils::OutputMapping &out_tensor_to_nodes); Status Prechecker(ge::onnx::GraphProto &onnx_graph); @@ -115,6 +117,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, std::map &name_to_onnx_graph); + Status SetOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes, + const ParserUtils::OutputMapping &tensor_to_nodes); + std::map ori_to_om_type_; std::map domain_verseion_; diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 8fde8fc..f188a97 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -1494,7 +1494,9 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro GE_RETURN_IF_ERROR(AddEdges(graph)); Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph); - GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph)); + ParserUtils::OutputMapping final_output_nodes; + GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes)); + GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes)); GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph)); GE_RETURN_IF_ERROR(graph->TopologicalSorting()); @@ -2304,7 +2306,9 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ret = AddEdges(graph); Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph); - GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph)); + ParserUtils::OutputMapping final_output_nodes; + GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes)); + GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes)); DeleteFuisonNodeDef(); GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed"); @@ -4020,6 +4024,16 @@ Status TensorFlowModelParser::CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compu } return SUCCESS; } + +Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes) { + auto &user_specified_nodes = ge::GetParserContext().user_out_nodes; + if (!user_specified_nodes.empty()) { + for (auto &output_node_info : user_specified_nodes) { + ParserUtils::UpdateOutputNodeInfo(final_output_nodes, output_node_info); + } + } + return SUCCESS; +} } // namespace ge namespace domi { diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index 23bf750..94f5bb3 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -44,6 +44,7 @@ #include "proto/tensorflow/graph_library.pb.h" #include "external/register/scope/scope_fusion_pass_register.h" #include "scope/scope_pass_manager.h" +#include "common/parser_utils.h" using ge::ScopePassManager; using domi::tensorflow::GraphDef; @@ -647,6 +648,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr &op_parser); Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); + static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); /** * save diff --git a/tests/depends/error_manager/src/error_manager_stub.cc b/tests/depends/error_manager/src/error_manager_stub.cc index 4d06caf..1fba95a 100644 --- a/tests/depends/error_manager/src/error_manager_stub.cc +++ b/tests/depends/error_manager/src/error_manager_stub.cc @@ -50,7 +50,7 @@ int ErrorManager::ReportInterErrMessage(std::string error_code, const std::strin const std::string &ErrorManager::GetLogHeader() { - static const std::string kLogHeader("GeUtStub"); + static const std::string kLogHeader("[ParserStub]"); return kLogHeader; } diff --git a/tests/depends/mmpa/src/mmpa_stub.cc b/tests/depends/mmpa/src/mmpa_stub.cc index dd332a7..d0097bc 100644 --- a/tests/depends/mmpa/src/mmpa_stub.cc +++ b/tests/depends/mmpa/src/mmpa_stub.cc @@ -275,10 +275,10 @@ INT32 mmGetPid() } INT32 mmDup2(INT32 oldFd, INT32 newFd) { - return -1; + return 0; } INT32 mmDup(INT32 fd) { - return -1; + return 0; } diff --git a/tests/depends/slog/src/slog_stub.cc b/tests/depends/slog/src/slog_stub.cc index edc245b..1ce0b1b 100644 --- a/tests/depends/slog/src/slog_stub.cc +++ b/tests/depends/slog/src/slog_stub.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,20 +15,53 @@ */ #include "toolchain/slog.h" +#include "toolchain/plog.h" #include #include -#include void dav_log(int module_id, const char *fmt, ...) {} -void DlogErrorInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } +static int log_level = DLOG_ERROR; -void DlogWarnInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } +#define __DO_PRINT() \ + do { \ + const int FMT_BUFF_SIZE = 1024; \ + char fmt_buff[FMT_BUFF_SIZE] = {0}; \ + va_list valist; \ + va_start(valist, fmt); \ + vsnprintf(fmt_buff, FMT_BUFF_SIZE, fmt, valist); \ + va_end(valist); \ + printf("%s \n", fmt_buff); \ + } while (0) -void DlogInfoInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } +void DlogErrorInner(int module_id, const char *fmt, ...) { + if (log_level > DLOG_ERROR) { + return; + } + __DO_PRINT(); +} + +void DlogWarnInner(int module_id, const char *fmt, ...) { + if (log_level > DLOG_WARN) { + return; + } + __DO_PRINT(); +} -void DlogDebugInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } +void DlogInfoInner(int module_id, const char *fmt, ...) { + if (log_level > DLOG_INFO) { + return; + } + __DO_PRINT(); +} + +void DlogDebugInner(int module_id, const char *fmt, ...) { + if (log_level > DLOG_DEBUG) { + return; + } + __DO_PRINT(); +} void DlogEventInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } @@ -38,11 +71,25 @@ void DlogWithKVInner(int module_id, int level, KeyValue *pst_kv_array, int kv_nu dav_log(module_id, fmt); } -int dlog_setlevel(int module_id, int level, int enable_event) { return DLOG_DEBUG; } +int dlog_setlevel(int module_id, int level, int enable_event) { + log_level = level; + return log_level; +} + +int dlog_getlevel(int module_id, int *enable_event) { return log_level; } -int dlog_getlevel(int module_id, int *enable_event) { return DLOG_DEBUG; } +int CheckLogLevel(int moduleId, int log_level_check) { return log_level >= log_level_check; } -int CheckLogLevel(int moduleId, int logLevel) -{ - return 1; -} +/** + * @ingroup plog + * @brief DlogReportInitialize: init log in service process before all device setting. + * @return: 0: SUCCEED, others: FAILED + */ +int DlogReportInitialize() { return 0; } + +/** + * @ingroup plog + * @brief DlogReportFinalize: release log resource in service process after all device reset. + * @return: 0: SUCCEED, others: FAILED + */ +int DlogReportFinalize() { return 0; } diff --git a/tests/st/CMakeLists.txt b/tests/st/CMakeLists.txt new file mode 100644 index 0000000..a835c25 --- /dev/null +++ b/tests/st/CMakeLists.txt @@ -0,0 +1,356 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +project(st_parser) + +set(CMAKE_CXX_STANDARD 11) + +################################################################################ +set(PARSER_PROTO_LIST + "${PARSER_DIR}/metadef/proto/om.proto" + "${PARSER_DIR}/metadef/proto/ge_ir.proto" + "${PARSER_DIR}/metadef/proto/task.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/attr_value.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/function.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/graph.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/graph_library.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/node_def.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/op_def.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/resource_handle.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/tensor.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/tensor_shape.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/types.proto" + "${PARSER_DIR}/metadef/proto/tensorflow/versions.proto" + "${PARSER_DIR}/metadef/proto/caffe/caffe.proto" + "${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto" + #"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto" +) + +protobuf_generate(ge PARSER_PROTO_SRCS PARSER_PROTO_HDRS ${PARSER_PROTO_LIST}) + +############ libst_parser_proto.a ############ +add_library(st_parser_proto STATIC + ${PARSER_PROTO_HDRS} ${PARSER_PROTO_SRCS} +) + +target_compile_definitions(st_parser_proto PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + google=ascend_private +) + +target_compile_options(st_parser_proto PRIVATE + -O2 -g -fno-common +) + +target_link_libraries(st_parser_proto PRIVATE + $ + ascend_protobuf +) + + +################################################################################ +set(DUPLICATE_PROTO_LIST + "${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto" +) + +protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) + +################################################################################ +set(MATEDEF_SRC_FILES + "${PARSER_DIR}/metadef/graph/aligned_ptr.cc" + "${PARSER_DIR}/metadef/graph/anchor.cc" + "${PARSER_DIR}/metadef/graph/ascend_string.cc" + "${PARSER_DIR}/metadef/graph/attr_value.cc" + "${PARSER_DIR}/metadef/graph/buffer.cc" + "${PARSER_DIR}/metadef/graph/compute_graph.cc" + "${PARSER_DIR}/metadef/graph/debug/graph_debug.cc" + "${PARSER_DIR}/metadef/graph/detail/attributes_holder.cc" + "${PARSER_DIR}/metadef/graph/format_refiner.cc" + "${PARSER_DIR}/metadef/graph/ge_attr_define.cc" + "${PARSER_DIR}/metadef/graph/ge_tensor.cc" + "${PARSER_DIR}/metadef/graph/gnode.cc" + "${PARSER_DIR}/metadef/graph/graph.cc" + "${PARSER_DIR}/metadef/graph/inference_context.cc" + "${PARSER_DIR}/metadef/graph/model.cc" + "${PARSER_DIR}/metadef/graph/model_serialize.cc" + "${PARSER_DIR}/metadef/graph/node.cc" + "${PARSER_DIR}/metadef/graph/op_desc.cc" + "${PARSER_DIR}/metadef/graph/operator.cc" + "${PARSER_DIR}/metadef/graph/operator_factory.cc" + "${PARSER_DIR}/metadef/graph/operator_factory_impl.cc" + "${PARSER_DIR}/metadef/graph/opsproto/opsproto_manager.cc" + "${PARSER_DIR}/metadef/graph/option/ge_context.cc" + "${PARSER_DIR}/metadef/graph/option/ge_local_context.cc" + "${PARSER_DIR}/metadef/graph/ref_relation.cc" + "${PARSER_DIR}/metadef/graph/runtime_inference_context.cc" + "${PARSER_DIR}/metadef/graph/shape_refiner.cc" + "${PARSER_DIR}/metadef/graph/tensor.cc" + "${PARSER_DIR}/metadef/graph/types.cc" + "${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" + "${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" + "${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" + "${PARSER_DIR}/metadef/graph/utils/node_utils.cc" + "${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" + "${PARSER_DIR}/metadef/graph/utils/tensor_utils.cc" + "${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" + "${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" + "${PARSER_DIR}/metadef/graph/utils/type_utils.cc" + "${PARSER_DIR}/metadef/ops/op_imp.cpp" + "${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" + "${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" + "${PARSER_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cc" +) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${PARSER_DIR}/metadef/inc) +include_directories(${PARSER_DIR}/metadef/inc/graph) +include_directories(${PARSER_DIR}/metadef/inc/external) +include_directories(${PARSER_DIR}/metadef/inc/external/graph) +include_directories(${PARSER_DIR}/metadef/graph) +include_directories(${PARSER_DIR}/metadef/third_party) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external/ge) +include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc) +include_directories(${PARSER_DIR}/metadef/third_party/transformer/inc) +include_directories(${PARSER_DIR}/metadef) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) +include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) + +############ libst_parser_graph.a ############ +add_library(st_parser_graph STATIC + ${MATEDEF_SRC_FILES} ${PARSER_PROTO_HDRS} ${DUP_PROTO_HDRS} +) + +target_compile_definitions(st_parser_graph PRIVATE + google=ascend_private +) + +target_compile_options(st_parser_graph PRIVATE + -O2 -g -fno-common +) + +target_link_libraries(st_parser_graph PRIVATE + $ + c_sec ascend_protobuf +) + + +################################################################################ +set(REGISTER_SRC_FILES + "${PARSER_DIR}/metadef/register/auto_mapping_util.cpp" + "${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc" + "${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc" + "${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc" + "${PARSER_DIR}/metadef/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc" + "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc" + "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.cc" + "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc" + "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc" + "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc" + "${PARSER_DIR}/metadef/register/host_cpu_context.cc" + "${PARSER_DIR}/metadef/register/infer_data_slice_registry.cc" + "${PARSER_DIR}/metadef/register/ops_kernel_builder_registry.cc" + "${PARSER_DIR}/metadef/register/op_kernel_registry.cpp" + "${PARSER_DIR}/metadef/register/op_tiling.cpp" + "${PARSER_DIR}/metadef/register/op_tiling_registry.cpp" + "${PARSER_DIR}/metadef/register/register.cpp" + "${PARSER_DIR}/metadef/register/register_format_transfer.cc" + "${PARSER_DIR}/metadef/register/register_pass.cpp" + "${PARSER_DIR}/metadef/register/scope/scope_graph.cc" + "${PARSER_DIR}/metadef/register/scope/scope_pass.cc" + "${PARSER_DIR}/metadef/register/scope/scope_pass_registry.cc" + "${PARSER_DIR}/metadef/register/scope/scope_pattern.cc" + "${PARSER_DIR}/metadef/register/scope/scope_util.cc" + "${PARSER_DIR}/metadef/register/tensor_assign.cpp" + "${PARSER_DIR}/metadef/register/prototype_pass_registry.cc" +) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) +include_directories(${PARSER_DIR}/metadef) +include_directories(${PARSER_DIR}/metadef/graph) +include_directories(${PARSER_DIR}/metadef/inc) +include_directories(${PARSER_DIR}/metadef/inc/external) +include_directories(${PARSER_DIR}/metadef/inc/register) +include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) + +############ libst_parser_register.a ############ +add_library(st_parser_register STATIC + ${REGISTER_SRC_FILES} ${PARSER_PROTO_HDRS} +) + +target_compile_definitions(st_parser_register PRIVATE + google=ascend_private +) + +target_compile_options(st_parser_register PRIVATE + -O2 -g -fno-common +) + +target_link_libraries(st_parser_register PRIVATE + $ + c_sec ascend_protobuf json +) + + +################################################################################ +set(PARSER_SRC_FILES + "${PARSER_DIR}/parser/caffe/caffe_custom_parser_adapter.cc" + "${PARSER_DIR}/parser/caffe/caffe_data_parser.cc" + "${PARSER_DIR}/parser/caffe/caffe_op_parser.cc" + "${PARSER_DIR}/parser/caffe/caffe_parser.cc" + "${PARSER_DIR}/parser/caffe/caffe_reshape_parser.cc" + "${PARSER_DIR}/parser/common/acl_graph_parser_util.cc" + "${PARSER_DIR}/parser/common/convert/pb2json.cc" + "${PARSER_DIR}/parser/common/convert/message2operator.cc" + "${PARSER_DIR}/parser/common/data_op_parser.cc" + "${PARSER_DIR}/parser/common/model_saver.cc" + "${PARSER_DIR}/parser/common/op_def/arg_op.cc" + "${PARSER_DIR}/parser/common/op_def/constant_op.cc" + "${PARSER_DIR}/parser/common/op_def/defs.cc" + "${PARSER_DIR}/parser/common/op_def/fill_op.cc" + "${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc" + "${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc" + "${PARSER_DIR}/parser/common/op_def/no_op_op.cc" + "${PARSER_DIR}/parser/common/op_def/operator.cc" + "${PARSER_DIR}/parser/common/op_def/op_schema.cc" + "${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc" + "${PARSER_DIR}/parser/common/op_def/shape_n_op.cc" + "${PARSER_DIR}/parser/common/op_def/variable_op.cc" + "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc" + "${PARSER_DIR}/parser/common/op_map.cc" + "${PARSER_DIR}/parser/common/op_parser_factory.cc" + "${PARSER_DIR}/parser/common/parser_api.cc" + "${PARSER_DIR}/parser/common/parser_factory.cc" + "${PARSER_DIR}/parser/common/parser_fp16_t.cc" + "${PARSER_DIR}/parser/common/parser_inner_ctx.cc" + "${PARSER_DIR}/parser/common/parser_types.cc" + "${PARSER_DIR}/parser/common/parser_utils.cc" + "${PARSER_DIR}/parser/common/pass_manager.cc" + "${PARSER_DIR}/parser/common/pre_checker.cc" + "${PARSER_DIR}/parser/common/proto_file_parser.cc" + "${PARSER_DIR}/parser/common/prototype_pass_manager.cc" + "${PARSER_DIR}/parser/common/register_tbe.cc" + "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" + "${PARSER_DIR}/parser/common/thread_pool.cc" + "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" + "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" + "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" + "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" + "${PARSER_DIR}/parser/onnx/onnx_parser.cc" + "${PARSER_DIR}/parser/onnx/onnx_util.cc" + "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" + "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" + "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" + "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" + "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" + "${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_constant_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_custom_parser_adapter.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_data_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_enter_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_fill_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_frameworkop_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_fusionop_util.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_op_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_identity_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_merge_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_no_op_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_ref_switch_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_reshape_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_shape_n_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_squeeze_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_util.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_variable_v2_parser.cc" + "${PARSER_DIR}/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc" +) + +# include directories +include_directories(${CMAKE_CURRENT_LIST_DIR}) +include_directories(${CMAKE_BINARY_DIR}/proto/ge) +include_directories(${PARSER_DIR}) +include_directories(${PARSER_DIR}/inc) +include_directories(${PARSER_DIR}/parser) +include_directories(${PARSER_DIR}/parser/onnx) +include_directories(${PARSER_DIR}/tests) +include_directories(${PARSER_DIR}/metadef/inc) +include_directories(${PARSER_DIR}/metadef/inc/external) +include_directories(${PARSER_DIR}/metadef/inc/register) +include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external) +include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) + + +set(PARSER_ST_FILES + "parser_st_utils.cc" + "testcase/test_main.cc" + "testcase/test_onnx_parser.cc" + "testcase/test_caffe_parser.cc" + "testcase/test_tensorflow_parser.cc" +) + +############ libst_parser_common.a ############ +add_library(st_parser_common STATIC + ${PARSER_SRC_FILES} ${PARSER_PROTO_HDRS} +) + +target_compile_definitions(st_parser_common PRIVATE + google=ascend_private +) + +target_compile_options(st_parser_common PRIVATE + -g --coverage -fprofile-arcs -ftest-coverage + -Werror=format +) + +target_link_libraries(st_parser_common PRIVATE + $ + st_parser_proto st_parser_graph c_sec + ascend_protobuf + json +) + + +################################################################################ +add_executable(st_parser + ${PARSER_ST_FILES} ${PARSER_PROTO_SRCS} +) + +target_compile_options(st_parser PRIVATE + -g +) + +target_compile_definitions(st_parser PRIVATE + google=ascend_private +) + +target_link_libraries(st_parser + $ + st_parser_proto + -Wl,--whole-archive st_parser_common -Wl,--no-whole-archive + st_parser_graph st_parser_register error_manager_stub mmpa_stub attr_util_stub + gtest gtest_main slog_stub ascend_protobuf c_sec -lrt -ldl -lgcov +) diff --git a/tests/st/parser_st_utils.cc b/tests/st/parser_st_utils.cc new file mode 100644 index 0000000..ba04212 --- /dev/null +++ b/tests/st/parser_st_utils.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "st/parser_st_utils.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +void ParerSTestsUtils::ClearParserInnerCtx() { + ge::GetParserContext().input_nodes_format_map.clear(); + ge::GetParserContext().output_formats.clear(); + ge::GetParserContext().user_input_dims.clear(); + ge::GetParserContext().input_dims.clear(); + ge::GetParserContext().op_conf_map.clear(); + ge::GetParserContext().user_out_nodes.clear(); + ge::GetParserContext().default_out_nodes.clear(); + ge::GetParserContext().out_nodes_map.clear(); + ge::GetParserContext().user_out_tensors.clear(); + ge::GetParserContext().net_out_nodes.clear(); + ge::GetParserContext().out_tensor_names.clear(); + ge::GetParserContext().data_tensor_names.clear(); + ge::GetParserContext().is_dynamic_input = false; + ge::GetParserContext().train_flag = false; + ge::GetParserContext().format = domi::DOMI_TENSOR_ND; + ge::GetParserContext().type = domi::FRAMEWORK_RESERVED; + ge::GetParserContext().run_mode = GEN_OM_MODEL; + ge::GetParserContext().custom_proto_path = ""; + ge::GetParserContext().caffe_proto_path = ""; + ge::GetParserContext().enable_scope_fusion_passes = ""; + GELOGI("Clear parser inner context successfully."); +} +} // namespace ge diff --git a/tests/st/parser_st_utils.h b/tests/st/parser_st_utils.h new file mode 100644 index 0000000..50b6d06 --- /dev/null +++ b/tests/st/parser_st_utils.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_TESTS_UT_PARSER_H_ +#define GE_PARSER_TESTS_UT_PARSER_H_ + +#include "framework/omg/parser/parser_inner_ctx.h" + +namespace ge { +class ParerSTestsUtils { + public: + static void ClearParserInnerCtx(); +}; +} // namespace ge + +#endif // GE_PARSER_TESTS_UT_PARSER_H_ diff --git a/tests/st/testcase/origin_models/caffe_abs.pbtxt b/tests/st/testcase/origin_models/caffe_abs.pbtxt new file mode 100644 index 0000000..c032cff --- /dev/null +++ b/tests/st/testcase/origin_models/caffe_abs.pbtxt @@ -0,0 +1,14 @@ +name: "TestAbs" +layer { + name: "data" + type: "Input" + top: "data" + input_param { shape: { dim: 64 dim: 1 dim: 28 dim: 28 } } +} + +layer { + name: "abs" + type: "AbsVal" + bottom: "data" + top: "abs_out" +} \ No newline at end of file diff --git a/tests/st/testcase/origin_models/onnx_conv2d.onnx b/tests/st/testcase/origin_models/onnx_conv2d.onnx new file mode 100644 index 0000000..aa823ed Binary files /dev/null and b/tests/st/testcase/origin_models/onnx_conv2d.onnx differ diff --git a/tests/st/testcase/origin_models/onnx_if.onnx b/tests/st/testcase/origin_models/onnx_if.onnx new file mode 100644 index 0000000..ff2230a Binary files /dev/null and b/tests/st/testcase/origin_models/onnx_if.onnx differ diff --git a/tests/st/testcase/origin_models/tf_add.pb b/tests/st/testcase/origin_models/tf_add.pb new file mode 100644 index 0000000..b9a42a3 --- /dev/null +++ b/tests/st/testcase/origin_models/tf_add.pb @@ -0,0 +1,13 @@ + +8 + Placeholder Placeholder* +dtype0* +shape: +: + Placeholder_1 Placeholder* +dtype0* +shape: +6 + +add_test_1Add Placeholder Placeholder_1* +T0"† \ No newline at end of file diff --git a/tests/st/testcase/test_caffe_parser.cc b/tests/st/testcase/test_caffe_parser.cc new file mode 100644 index 0000000..668e16c --- /dev/null +++ b/tests/st/testcase/test_caffe_parser.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "parser/common/op_parser_factory.h" +#include "graph/operator_reg.h" +#include "register/op_registry.h" +#include "parser/common/register_tbe.h" +#include "framework/omg/parser/model_parser.h" +#include "framework/omg/parser/parser_factory.h" +#include "external/parser/caffe_parser.h" +#include "st/parser_st_utils.h" +#include "external/ge/ge_api_types.h" + +namespace ge { +class STestCaffeParser : public testing::Test { + protected: + void SetUp() { + ParerSTestsUtils::ClearParserInnerCtx(); + RegisterCustomOp(); + } + + void TearDown() {} + + public: + void RegisterCustomOp(); +}; + +static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { + return SUCCESS; +} +void STestCaffeParser::RegisterCustomOp() { + REGISTER_CUSTOM_OP("Data") + .FrameworkType(domi::CAFFE) + .OriginOpType("Input") + .ParseParamsFn(ParseParams); + + REGISTER_CUSTOM_OP("Abs") + .FrameworkType(domi::CAFFE) + .OriginOpType("AbsVal") + .ParseParamsFn(ParseParams); + + std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; + for (auto reg_data : reg_datas) { + OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegistry::Instance()->Register(reg_data); + } + domi::OpRegistry::Instance()->registrationDatas.clear(); +} + +namespace { +REG_OP(Data) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(index, Int, 0) + .OP_END_FACTORY_REG(Data) + +REG_OP(Abs) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(Abs) +} + +TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/caffe_abs.pbtxt"; + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); + ASSERT_NE(model_parser, nullptr); + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmp_graph"); + ASSERT_NE(compute_graph, nullptr); + ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto ret = model_parser->Parse(model_file.c_str(), graph); + ASSERT_EQ(ret, GRAPH_SUCCESS); + AclGrphParseUtil acl_graph_parse_util; + std::map parser_params; + auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); + ASSERT_EQ(status, SUCCESS); + + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); +} + +} // namespace ge \ No newline at end of file diff --git a/tests/st/testcase/test_main.cc b/tests/st/testcase/test_main.cc new file mode 100644 index 0000000..828f7c3 --- /dev/null +++ b/tests/st/testcase/test_main.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +using namespace std; + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + int ret = RUN_ALL_TESTS(); + std::cout << "Finish parser st." << std::endl; + return ret; +} diff --git a/tests/st/testcase/test_onnx_parser.cc b/tests/st/testcase/test_onnx_parser.cc new file mode 100644 index 0000000..98e4ef7 --- /dev/null +++ b/tests/st/testcase/test_onnx_parser.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "parser/common/op_parser_factory.h" +#include "graph/operator_reg.h" +#include "register/op_registry.h" +#include "parser/common/register_tbe.h" +#include "external/parser/onnx_parser.h" +#include "st/parser_st_utils.h" +#include "external/ge/ge_api_types.h" + +namespace ge { +class STestOnnxParser : public testing::Test { + protected: + void SetUp() { + ParerSTestsUtils::ClearParserInnerCtx(); + RegisterCustomOp(); + } + + void TearDown() {} + + public: + void RegisterCustomOp(); +}; + +static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { + return SUCCESS; +} + +static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) { + return SUCCESS; +} + +Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { + domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = + domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); + if (auto_mapping_subgraph_index_func == nullptr) { + std::cout<<"auto mapping if subgraph func is nullptr!"< Status { + parent_index = data_index + 1; + return SUCCESS; + }, + [&](int output_index, int &parent_index) -> Status { + parent_index = output_index; + return SUCCESS; + }); +} + +void STestOnnxParser::RegisterCustomOp() { + REGISTER_CUSTOM_OP("Conv2D") + .FrameworkType(domi::ONNX) + .OriginOpType("ai.onnx::11::Conv") + .ParseParamsFn(ParseParams); + + // register if op info to GE + REGISTER_CUSTOM_OP("If") + .FrameworkType(domi::ONNX) + .OriginOpType({"ai.onnx::9::If", + "ai.onnx::10::If", + "ai.onnx::11::If", + "ai.onnx::12::If", + "ai.onnx::13::If"}) + .ParseParamsFn(ParseParams) + .ParseParamsByOperatorFn(ParseParamByOpFunc) + .ParseSubgraphPostFn(ParseSubgraphPostFnIf); + + REGISTER_CUSTOM_OP("Add") + .FrameworkType(domi::ONNX) + .OriginOpType("ai.onnx::11::Add") + .ParseParamsFn(ParseParams); + + REGISTER_CUSTOM_OP("Identity") + .FrameworkType(domi::ONNX) + .OriginOpType("ai.onnx::11::Identity") + .ParseParamsFn(ParseParams); + + std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; + for (auto reg_data : reg_datas) { + OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegistry::Instance()->Register(reg_data); + } + domi::OpRegistry::Instance()->registrationDatas.clear(); +} + +namespace { +REG_OP(Data) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(index, Int, 0) + .OP_END_FACTORY_REG(Data) + +REG_OP(Const) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(value, Tensor, Tensor()) + .OP_END_FACTORY_REG(Const) + +REG_OP(Conv2D) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) + .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dilations, ListInt, {1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NHWC") + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(Conv2D) + +REG_OP(If) + .INPUT(cond, TensorType::ALL()) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(then_branch) + .GRAPH(else_branch) + .OP_END_FACTORY_REG(If) + +REG_OP(Add) + .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OP_END_FACTORY_REG(Add) + +REG_OP(Identity) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(Identity) +} + +TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/onnx_conv2d.onnx"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); + ASSERT_EQ(ret, GRAPH_SUCCESS); + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y"); +} + +TEST_F(STestOnnxParser, onnx_parser_if_node) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/onnx_if.onnx"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); + EXPECT_EQ(ret, GRAPH_SUCCESS); +} +} // namespace ge \ No newline at end of file diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc new file mode 100644 index 0000000..92f7698 --- /dev/null +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "parser/common/op_parser_factory.h" +#include "parser/tensorflow/tensorflow_parser.h" +#include "graph/operator_reg.h" +#include "register/op_registry.h" +#include "parser/common/register_tbe.h" +#include "external/parser/tensorflow_parser.h" +#include "st/parser_st_utils.h" + +namespace ge { +class STestTensorflowParser : public testing::Test { + protected: + void SetUp() { + ParerSTestsUtils::ClearParserInnerCtx(); + } + + void TearDown() {} + + public: + void RegisterCustomOp(); +}; + +static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { + return SUCCESS; +} + +void STestTensorflowParser::RegisterCustomOp() { + REGISTER_CUSTOM_OP("Add") + .FrameworkType(domi::TENSORFLOW) + .OriginOpType("Add") + .ParseParamsFn(ParseParams); + + std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; + for (auto reg_data : reg_datas) { + OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegistry::Instance()->Register(reg_data); + } + domi::OpRegistry::Instance()->registrationDatas.clear(); +} + +namespace { +REG_OP(Data) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(index, Int, 0) + .OP_END_FACTORY_REG(Data) + +REG_OP(Add) + .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OP_END_FACTORY_REG(Add) +} + +TEST_F(STestTensorflowParser, tensorflow_parser_success) { + RegisterCustomOp(); + + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/tf_add.pb"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); + ASSERT_EQ(ret, SUCCESS); + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "add_test_1"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 74d7c84..b93e231 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -296,6 +296,7 @@ include_directories(${PARSER_DIR}) include_directories(${PARSER_DIR}/inc) include_directories(${PARSER_DIR}/parser) include_directories(${PARSER_DIR}/parser/onnx) +include_directories(${PARSER_DIR}/tests) include_directories(${PARSER_DIR}/metadef/inc) include_directories(${PARSER_DIR}/metadef/inc/external) include_directories(${PARSER_DIR}/metadef/inc/register) @@ -306,7 +307,10 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) set(PARSER_UT_FILES + "parser_ut_utils.cc" + "testcase/common/acl_graph_parser_unittest.cc" "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" + "testcase/caffe_parser_testcase/caffe_parser_unittest.cc" "testcase/onnx_parser_testcase/message2operator_unittest.cc" "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" "testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc" diff --git a/tests/ut/parser/parser_ut_utils.cc b/tests/ut/parser/parser_ut_utils.cc new file mode 100644 index 0000000..202ca1d --- /dev/null +++ b/tests/ut/parser/parser_ut_utils.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ut/parser/parser_ut_utils.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +void ParerUTestsUtils::ClearParserInnerCtx() { + ge::GetParserContext().input_nodes_format_map.clear(); + ge::GetParserContext().output_formats.clear(); + ge::GetParserContext().user_input_dims.clear(); + ge::GetParserContext().input_dims.clear(); + ge::GetParserContext().op_conf_map.clear(); + ge::GetParserContext().user_out_nodes.clear(); + ge::GetParserContext().default_out_nodes.clear(); + ge::GetParserContext().out_nodes_map.clear(); + ge::GetParserContext().user_out_tensors.clear(); + ge::GetParserContext().net_out_nodes.clear(); + ge::GetParserContext().out_tensor_names.clear(); + ge::GetParserContext().data_tensor_names.clear(); + ge::GetParserContext().is_dynamic_input = false; + ge::GetParserContext().train_flag = false; + ge::GetParserContext().format = domi::DOMI_TENSOR_ND; + ge::GetParserContext().type = domi::FRAMEWORK_RESERVED; + ge::GetParserContext().run_mode = GEN_OM_MODEL; + ge::GetParserContext().custom_proto_path = ""; + ge::GetParserContext().caffe_proto_path = ""; + ge::GetParserContext().enable_scope_fusion_passes = ""; + GELOGI("Clear parser inner context successfully."); +} +} // namespace ge diff --git a/tests/ut/parser/parser_ut_utils.h b/tests/ut/parser/parser_ut_utils.h new file mode 100644 index 0000000..38596b6 --- /dev/null +++ b/tests/ut/parser/parser_ut_utils.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_TESTS_UT_PARSER_H_ +#define GE_PARSER_TESTS_UT_PARSER_H_ + +#include "framework/omg/parser/parser_inner_ctx.h" + +namespace ge { +class ParerUTestsUtils { + public: + static void ClearParserInnerCtx(); +}; +} // namespace ge + +#endif // GE_PARSER_TESTS_UT_PARSER_H_ diff --git a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_model/caffe_abs.pbtxt b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_model/caffe_abs.pbtxt new file mode 100644 index 0000000..c032cff --- /dev/null +++ b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_model/caffe_abs.pbtxt @@ -0,0 +1,14 @@ +name: "TestAbs" +layer { + name: "data" + type: "Input" + top: "data" + input_param { shape: { dim: 64 dim: 1 dim: 28 dim: 28 } } +} + +layer { + name: "abs" + type: "AbsVal" + bottom: "data" + top: "abs_out" +} \ No newline at end of file diff --git a/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc new file mode 100644 index 0000000..93e87be --- /dev/null +++ b/tests/ut/parser/testcase/caffe_parser_testcase/caffe_parser_unittest.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "parser/common/op_parser_factory.h" +#include "graph/operator_reg.h" +#include "external/graph/types.h" +#include "register/op_registry.h" +#include "parser/common/register_tbe.h" +#include "framework/omg/parser/model_parser.h" +#include "framework/omg/parser/parser_factory.h" +#include "external/parser/caffe_parser.h" +#include "ut/parser/parser_ut_utils.h" +#include "external/ge/ge_api_types.h" + +namespace ge { +class UtestCaffeParser : public testing::Test { + protected: + void SetUp() { + ParerUTestsUtils::ClearParserInnerCtx(); + RegisterCustomOp(); + } + + void TearDown() {} + + public: + void RegisterCustomOp(); +}; + +static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { + return SUCCESS; +} +void UtestCaffeParser::RegisterCustomOp() { + REGISTER_CUSTOM_OP("Data") + .FrameworkType(domi::CAFFE) + .OriginOpType("Input") + .ParseParamsFn(ParseParams); + + REGISTER_CUSTOM_OP("Abs") + .FrameworkType(domi::CAFFE) + .OriginOpType("AbsVal") + .ParseParamsFn(ParseParams); + + std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; + for (auto reg_data : reg_datas) { + OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegistry::Instance()->Register(reg_data); + } + domi::OpRegistry::Instance()->registrationDatas.clear(); +} + +namespace { +REG_OP(Data) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(index, Int, 0) + .OP_END_FACTORY_REG(Data) + +REG_OP(Abs) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(Abs) +} + +TEST_F(UtestCaffeParser, caffe_parser_user_output_with_name_and_index) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt"; + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); + ASSERT_NE(model_parser, nullptr); + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); + ASSERT_NE(compute_graph, nullptr); + ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + + ge::GetParserContext().user_out_nodes.push_back({"abs", 0}); + auto ret = model_parser->Parse(model_file.c_str(), graph); + ASSERT_EQ(ret, GRAPH_SUCCESS); + AclGrphParseUtil acl_graph_parse_util; + std::map parser_params; + auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); + ASSERT_EQ(status, SUCCESS); + + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); +} + +TEST_F(UtestCaffeParser, caffe_parser_user_output_with_top_name) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt"; + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); + ASSERT_NE(model_parser, nullptr); + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); + ASSERT_NE(compute_graph, nullptr); + ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + + ge::GetParserContext().user_out_tensors.push_back("abs_out"); + auto ret = model_parser->Parse(model_file.c_str(), graph); + ASSERT_EQ(ret, GRAPH_SUCCESS); + AclGrphParseUtil acl_graph_parse_util; + std::map parser_params; + auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); + ASSERT_EQ(status, SUCCESS); + + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); +} + +TEST_F(UtestCaffeParser, caffe_parser_user_output_with_default) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt"; + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); + ASSERT_NE(model_parser, nullptr); + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); + ASSERT_NE(compute_graph, nullptr); + ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto ret = model_parser->Parse(model_file.c_str(), graph); + ASSERT_EQ(ret, GRAPH_SUCCESS); + AclGrphParseUtil acl_graph_parse_util; + std::map parser_params; + auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); + ASSERT_EQ(status, SUCCESS); + + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); +} + +} // namespace ge \ No newline at end of file diff --git a/tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc b/tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc new file mode 100644 index 0000000..68269b8 --- /dev/null +++ b/tests/ut/parser/testcase/common/acl_graph_parser_unittest.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "parser/common/op_parser_factory.h" +#include "graph/operator_reg.h" +#include "external/graph/types.h" +#include "register/op_registry.h" +#include "parser/common/register_tbe.h" +#include "external/parser/onnx_parser.h" +#include "ut/parser/parser_ut_utils.h" +#include "external/ge/ge_api_types.h" + +namespace ge { +class UtestAclGraphParser : public testing::Test { + protected: + void SetUp() { + + } + void TearDown() {} +}; + +TEST_F(UtestAclGraphParser, test_parse_acl_output_nodes) { + AclGrphParseUtil acl_graph_parse_util; + string graph_name; + // case 1: Normal with 'node and index' + ParerUTestsUtils::ClearParserInnerCtx(); + GetParserContext().type = domi::ONNX; + std::map out_nodes_with_node_and_index = { + {AscendString(ge::ir_option::OUT_NODES), AscendString("Out1:0;Out2:1")}}; + ParerUTestsUtils::ClearParserInnerCtx(); + auto ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_node_and_index, graph_name); + ASSERT_EQ(ret, SUCCESS); + EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 2); + EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 2); + EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 0); + + // case 2: Normal with 'tensor name' + ParerUTestsUtils::ClearParserInnerCtx(); + GetParserContext().type = domi::ONNX; + std::map out_nodes_with_tensor_name = { + {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out_tensor_2")}}; + ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_tensor_name, graph_name); + ASSERT_EQ(ret, SUCCESS); + EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0); + EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0); + EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 2); + + // case 3: Failed with 'node and index' before 'tensor name' + ParerUTestsUtils::ClearParserInnerCtx(); + GetParserContext().type = domi::ONNX; + std::map out_nodes_mode_mixex_pre = { + {AscendString(ge::ir_option::OUT_NODES), AscendString("Out1:0;Out2:1;Out_tensor_1;Out_tensor_2")}}; + ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_pre, graph_name); + ASSERT_EQ(ret, PARAM_INVALID); + EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 2); + EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 2); + EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 0); + + // case 4: Failed with 'node and index' inserted in 'tensor name' + ParerUTestsUtils::ClearParserInnerCtx(); + GetParserContext().type = domi::ONNX; + std::map out_nodes_mode_mixex_mid = { + {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out1:0;Out2:1;Out_tensor_2")}}; + ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_mid, graph_name); + ASSERT_EQ(ret, PARAM_INVALID); + EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0); + EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0); + EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 1); + + // case 5: Failed with 'node and index' after 'tensor name' + ParerUTestsUtils::ClearParserInnerCtx(); + GetParserContext().type = domi::ONNX; + std::map out_nodes_mode_mixex_post = { + {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out_tensor_2;Out1:0;Out2:1")}}; + ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_post, graph_name); + ASSERT_EQ(ret, PARAM_INVALID); + EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0); + EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0); + EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 2); + +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index 678b8a6..732b00a 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -22,12 +22,16 @@ #include "register/op_registry.h" #include "parser/common/register_tbe.h" #include "external/parser/onnx_parser.h" - +#include "ut/parser/parser_ut_utils.h" +#include "external/ge/ge_api_types.h" namespace ge { class UtestOnnxParser : public testing::Test { protected: - void SetUp() {} + void SetUp() { + ParerUTestsUtils::ClearParserInnerCtx(); + RegisterCustomOp(); + } void TearDown() {} @@ -152,28 +156,81 @@ REG_OP(Identity) .OP_END_FACTORY_REG(Identity) } -TEST_F(UtestOnnxParser, onnx_parser_success) { - RegisterCustomOp(); +TEST_F(UtestOnnxParser, onnx_parser_if_node) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/onnx_model/if.onnx"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); + EXPECT_EQ(ret, GRAPH_SUCCESS); +} +TEST_F(UtestOnnxParser, onnx_parser_user_output_with_name_and_index) { std::string case_dir = __FILE__; case_dir = case_dir.substr(0, case_dir.find_last_of("/")); std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; std::map parser_params; + parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("Conv_0:0")}); ge::Graph graph; auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); - EXPECT_EQ(ret, domi::SUCCESS); + ASSERT_EQ(ret, GRAPH_SUCCESS); + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "Conv_0:0"); } -TEST_F(UtestOnnxParser, onnx_parser_if_node) { - RegisterCustomOp(); +TEST_F(UtestOnnxParser, onnx_parser_user_output_with_tensor) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; + std::map parser_params; + parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("y")}); + ge::Graph graph; + auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); + ASSERT_EQ(ret, GRAPH_SUCCESS); + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y"); +} +TEST_F(UtestOnnxParser, onnx_parser_user_output_with_default) { std::string case_dir = __FILE__; case_dir = case_dir.substr(0, case_dir.find_last_of("/")); - std::string model_file = case_dir + "/onnx_model/if.onnx"; + std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); + ASSERT_EQ(ret, GRAPH_SUCCESS); + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); + ASSERT_EQ(output_nodes_info.size(), 1); + EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0"); + EXPECT_EQ((output_nodes_info.at(0).second), 0); + auto &net_out_name = ge::GetParserContext().net_out_nodes; + ASSERT_EQ(net_out_name.size(), 1); + EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y"); +} + +TEST_F(UtestOnnxParser, onnx_parser_user_output_with_tensor_failed) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; std::map parser_params; + parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("not_exist_output")}); ge::Graph graph; auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); - EXPECT_EQ(ret, domi::SUCCESS); + EXPECT_EQ(ret, FAILED); } } // namespace ge \ No newline at end of file diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index f597289..86f2b8f 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -24,12 +24,14 @@ #include "register/op_registry.h" #include "parser/common/register_tbe.h" #include "external/parser/tensorflow_parser.h" - +#include "ut/parser/parser_ut_utils.h" namespace ge { class UtestTensorflowParser : public testing::Test { protected: - void SetUp() {} + void SetUp() { + ParerUTestsUtils::ClearParserInnerCtx(); + } void TearDown() {}