| @@ -601,17 +601,17 @@ void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_ind | |||
| } | |||
| Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) { | |||
| if (ge::GetParserContext().user_out_nodes_top_vec.empty()) { | |||
| if (ge::GetParserContext().user_out_tensors.empty()) { | |||
| return SUCCESS; | |||
| } | |||
| ge::GetParserContext().out_nodes_map.clear(); | |||
| ge::GetParserContext().user_out_nodes.clear(); | |||
| int32_t layer_count = proto_message.layer_size(); | |||
| const std::vector<string> &user_out_nodes_top_vec = | |||
| ge::GetParserContext().user_out_nodes_top_vec; | |||
| const std::vector<string> &user_out_tensors = | |||
| ge::GetParserContext().user_out_tensors; | |||
| for (const auto &top_name : user_out_nodes_top_vec) { | |||
| for (const auto &top_name : user_out_tensors) { | |||
| bool find_node_falg = false; | |||
| string layer_name; | |||
| int32_t top_index = 0; | |||
| @@ -1082,7 +1082,7 @@ Status CaffeModelParser::AddUserOutNodesTop() { | |||
| string top_name = layer_iter->second[out_pair.second]; | |||
| auto top_node_iter = node_map.find(out_pair.first); | |||
| if (top_node_iter != node_map.end()) { | |||
| ge::GetParserContext().out_top_names.push_back(top_name); | |||
| ge::GetParserContext().out_tensor_names.push_back(top_name); | |||
| GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); | |||
| } | |||
| ++index; | |||
| @@ -1129,7 +1129,7 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes | |||
| auto top_node_iter = node_map.find(layer.name()); | |||
| GELOGI("output in top_blob: %s", layer.name().c_str()); | |||
| if (top_node_iter != node_map.end()) { | |||
| ge::GetParserContext().out_top_names.push_back(top_origin); | |||
| ge::GetParserContext().out_tensor_names.push_back(top_origin); | |||
| ge::GetParserContext().default_out_nodes.push_back(std::make_pair(layer.name(), (int32_t)i)); | |||
| GELOGI("The top of out node [%s] is [%s]", layer.name().c_str(), top_origin.c_str()); | |||
| } | |||
| @@ -1389,13 +1389,13 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la | |||
| } | |||
| string top_name = layer.top(0); | |||
| auto data_top_names = ge::GetParserContext().data_top_names; | |||
| if (find(data_top_names.begin(), data_top_names.end(), top_name) != data_top_names.end()) { | |||
| auto data_tensor_names = ge::GetParserContext().data_tensor_names; | |||
| if (find(data_tensor_names.begin(), data_tensor_names.end(), top_name) != data_tensor_names.end()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11036", {"topname"}, {top_name}); | |||
| GELOGE(FAILED, "[Check][Node]Different data node can not have same top name: %s.", top_name.c_str()); | |||
| return FAILED; | |||
| } | |||
| ge::GetParserContext().data_top_names.push_back(top_name); | |||
| ge::GetParserContext().data_tensor_names.push_back(top_name); | |||
| } | |||
| return SUCCESS; | |||
| @@ -1464,18 +1464,18 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||
| int32_t layer_count = proto_message.layer_size(); | |||
| if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { | |||
| if (!ge::GetParserContext().user_out_tensors.empty()) { | |||
| GELOGW("The out_put info has top_name items."); | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputNodeTopInfo(proto_message), | |||
| "[Parse][OutputNodeTopInfo] failed."); | |||
| ge::GetParserContext().user_out_nodes_top_vec.clear(); | |||
| ge::GetParserContext().user_out_tensors.clear(); | |||
| } | |||
| std::map<std::string, std::string> inplace_blob_name_remapping; | |||
| // Map of operator name and occurrence times | |||
| std::map<std::string, int32_t> layer_name_map; | |||
| GetParserContext().data_top_names.clear(); | |||
| GetParserContext().data_tensor_names.clear(); | |||
| // <layername,paramnames> | |||
| std::map<std::string, std::vector<std::string>> layer_params_map; | |||
| // same param name set <paramnames,layernames> | |||
| @@ -52,6 +52,13 @@ const int kMaxFileSizeLimit = INT_MAX; | |||
| const int kMaxBuffSize = 256; | |||
| const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. | |||
| const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M | |||
| const uint32_t kSetOutputWithNodeAndIndex = 0x1; | |||
| const uint32_t kSetOutputWithTensorName = 0x2; | |||
| const uint32_t kSetOutputModeMixed = 0x3; | |||
| const std::unordered_set<domi::FrameworkType> kSupportTensorAsOutput = { | |||
| domi::CAFFE, | |||
| domi::ONNX | |||
| }; | |||
| static string GetSoPath() { | |||
| Dl_info dl_info; | |||
| @@ -263,14 +270,19 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { | |||
| if (!out_nodes.empty()) { | |||
| ge::GetParserContext().out_nodes_map.clear(); | |||
| ge::GetParserContext().user_out_nodes.clear(); | |||
| ge::GetParserContext().user_out_nodes_top_vec.clear(); | |||
| ge::GetParserContext().user_out_tensors.clear(); | |||
| uint32_t set_output_mode = 0; | |||
| vector<string> nodes_v = StringUtils::Split(out_nodes, ';'); | |||
| for (const string &node : nodes_v) { | |||
| vector<string> key_value_v = StringUtils::Split(node, ':'); | |||
| if (key_value_v.size() != 2) { // The size must be 2. | |||
| if (key_value_v.size() == 1 && ge::GetParserContext().type == domi::CAFFE) { | |||
| ge::GetParserContext().user_out_nodes_top_vec.push_back(node); | |||
| if (key_value_v.size() == 1 && kSupportTensorAsOutput.count(ge::GetParserContext().type) > 0) { | |||
| set_output_mode |= kSetOutputWithTensorName; | |||
| if (set_output_mode == kSetOutputModeMixed) { | |||
| break; | |||
| } | |||
| ge::GetParserContext().user_out_tensors.push_back(node); | |||
| continue; | |||
| } | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| @@ -281,12 +293,9 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { | |||
| node.c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {"out_nodes", out_nodes, "is not all index or top_name"}); | |||
| GELOGE(PARAM_INVALID, "[Check][Param] This out_nodes str must be all index or top_name, " | |||
| "while the actual input is %s", out_nodes.c_str()); | |||
| return PARAM_INVALID; | |||
| set_output_mode |= kSetOutputWithNodeAndIndex; | |||
| if (set_output_mode == kSetOutputModeMixed) { | |||
| break; | |||
| } | |||
| // stoi: The method may throw an exception: invalid_argument/out_of_range | |||
| if (!CheckDigitStr(key_value_v[1])) { | |||
| @@ -309,6 +318,13 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { | |||
| } | |||
| ge::GetParserContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); | |||
| } | |||
| if (set_output_mode == kSetOutputModeMixed) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {"--out_nodes", out_nodes, "is not all index or top_name"}); | |||
| GELOGE(PARAM_INVALID, "[Parse][Param]This out_nodes str must be all index or tensor_name, " | |||
| "while the actual input is %s", out_nodes.c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| } | |||
| } catch (std::invalid_argument &) { | |||
| GELOGE(PARAM_INVALID, "[Check][Param] Invalid of out_nodes: %s ", out_nodes.c_str()); | |||
| @@ -410,10 +426,11 @@ domi::Status AclGrphParseUtil::ParseAclInputFp16Nodes(const ComputeGraphPtr &gra | |||
| return SUCCESS; | |||
| } | |||
| void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name) { | |||
| void AclGrphParseUtil::CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name) { | |||
| output_nodes_name.clear(); | |||
| if (ge::GetParserContext().out_top_names.empty()) { | |||
| auto &out_tensor_names = ge::GetParserContext().out_tensor_names; | |||
| if (out_tensor_names.empty()) { | |||
| // tf process, no top name. | |||
| for (const auto output_node_info : output_nodes_info) { | |||
| std::string node_name = output_node_info.first->GetName(); | |||
| @@ -422,13 +439,18 @@ void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::Node | |||
| } | |||
| return; | |||
| } | |||
| // caffe process, need add top name after node_name:index | |||
| // Need add top name after node_name:index | |||
| for (size_t i = 0; i < output_nodes_info.size(); ++i) { | |||
| std::string node_name = output_nodes_info[i].first->GetName(); | |||
| auto node = output_nodes_info[i].first; | |||
| int32_t index = output_nodes_info[i].second; | |||
| if (i < ge::GetParserContext().out_top_names.size()) { | |||
| output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + | |||
| ge::GetParserContext().out_top_names[i]); | |||
| std::string node_name = node->GetName(); | |||
| if (i < out_tensor_names.size()) { | |||
| auto output_desc = node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(index)); | |||
| (void)AttrUtils::SetStr(output_desc, ATTR_NAME_ORIGIN_OUTPUT_TENSOR_NAME, out_tensor_names[i]); | |||
| std::string output_name = node->GetName() + ":" + std::to_string(index) + ":" + out_tensor_names[i]; | |||
| output_nodes_name.push_back(output_name); | |||
| GELOGD("Output[%zu] name[%s]", i, output_name.c_str()); | |||
| } else { | |||
| GELOGW("Get top name of node [%s] fail.", node_name.c_str()); | |||
| output_nodes_name.push_back(node_name + ":" + std::to_string(index)); | |||
| @@ -469,7 +491,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, | |||
| domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph, | |||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) { | |||
| std::vector<std::pair<std::string, int32_t>> default_out_nodes = ge::GetParserContext().default_out_nodes; | |||
| if (ge::GetParserContext().type == domi::CAFFE && !default_out_nodes.empty()) { | |||
| if (!default_out_nodes.empty()) { | |||
| for (uint32_t i = 0; i < default_out_nodes.size(); ++i) { | |||
| ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); | |||
| if (out_node == nullptr) { | |||
| @@ -543,7 +565,7 @@ domi::Status AclGrphParseUtil::SetOutputNodeInfo(ge::Graph &graph, | |||
| return domi::FAILED; | |||
| } | |||
| } | |||
| GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); | |||
| CreateOutputNodesInfo(output_nodes_info, output_nodes_name); | |||
| compute_graph->SetGraphOutNodesInfo(output_nodes_info); | |||
| ge::GetParserContext().net_out_nodes = output_nodes_name; | |||
| GELOGI("Set graph %s output node success.", graph.GetName().c_str()); | |||
| @@ -50,8 +50,8 @@ class AclGrphParseUtil { | |||
| bool parser_initialized = false; | |||
| domi::Status CheckOptions(const std::map<AscendString, AscendString> &parser_params); | |||
| domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name); | |||
| void CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name); | |||
| void SetDefaultFormat(); | |||
| domi::Status ParseAclOutputNodes(const std::string &out_nodes); | |||
| domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16); | |||
| @@ -71,7 +71,7 @@ Status HandleNewOp(const NodePtr &node, | |||
| } | |||
| } | |||
| Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { | |||
| Status ParserUtils::ExpandOneToManyGraph(Graph &graph, OutputMapping &output_mapping) { | |||
| GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); | |||
| for (const auto &gn : graph.GetDirectNode()) { | |||
| NodePtr n = NodeAdapter::GNode2Node(gn); | |||
| @@ -95,7 +95,7 @@ Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { | |||
| GELOGE(FAILED, "[Invoke][ParseOpToGraphFunc]Get one to many graph failed for op:%s.", op.GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| ret = ExpandNodeToSubgraph(subgraph, n, graph); | |||
| ret = ExpandNodeToSubgraph(subgraph, n, graph, output_mapping); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "[Invoke][ExpandNodeToSubgraph]Expand one to many graph failed for op:%s.", op.GetName().c_str()); | |||
| return FAILED; | |||
| @@ -105,7 +105,8 @@ Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { | |||
| return SUCCESS; | |||
| } | |||
| Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph) { | |||
| Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph, | |||
| OutputMapping &output_mapping) { | |||
| ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph); | |||
| GE_CHECK_NOTNULL(sub_compute_graph); | |||
| ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | |||
| @@ -135,7 +136,7 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n | |||
| // handle output context. | |||
| std::vector<std::pair<NodePtr, int32_t>> out_node_index = sub_compute_graph->GetGraphOutNodesInfo(); | |||
| ret = HandleOutputContext(node, out_node_index); | |||
| ret = HandleOutputContext(node, out_node_index, output_mapping); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(FAILED, "[Run][HandleOutputContext] failed, node:%s.", node->GetName().c_str()); | |||
| return FAILED; | |||
| @@ -235,7 +236,8 @@ Status ParserUtils::HandleInputContext(const NodePtr &node, | |||
| } | |||
| Status ParserUtils::HandleOutputContext(const NodePtr &node, | |||
| const std::vector<std::pair<NodePtr, int32_t>> &out_node_index) { | |||
| const std::vector<std::pair<NodePtr, int32_t>> &out_node_index, | |||
| OutputMapping &output_mapping) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GELOGD("The size of out node is %zu", out_node_index.size()); | |||
| for (size_t index = 0; index < out_node_index.size(); index++) { | |||
| @@ -247,6 +249,8 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node, | |||
| NodePtr out_node = out_node_index[index].first; | |||
| int32_t out_index = out_node_index[index].second; | |||
| GELOGD("Begin to handle output node:%s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index); | |||
| std::string key = GenOutputKey({node->GetName(), index}); | |||
| output_mapping[key] = std::make_pair(out_node->GetName(), out_index); | |||
| auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor. | |||
| GE_CHECK_NOTNULL(src_out_anchor); | |||
| for (const auto &dest_in_anchor : node_out_anchor->GetPeerInDataAnchors()) { | |||
| @@ -273,4 +277,26 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node, | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| string ParserUtils::GenOutputKey(const OutputNodeInfo &node_info) { | |||
| return node_info.first + ":" + std::to_string(node_info.second); | |||
| } | |||
| void ParserUtils::UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info) { | |||
| std::string key = ParserUtils::GenOutputKey(output_node_info); | |||
| auto iter = final_output_nodes.find(key); | |||
| if (iter != final_output_nodes.end()) { | |||
| output_node_info = iter->second; | |||
| GELOGD("Update output node info, origin[%s], now[%s].", | |||
| key.c_str(), ParserUtils::GenOutputKey(output_node_info).c_str()); | |||
| } | |||
| } | |||
| void ParserUtils::UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes) { | |||
| for (auto &tensor_to_node : tensor_to_nodes) { | |||
| std::string tensor_name = tensor_to_node.first; | |||
| auto &output_node_info = tensor_to_node.second; | |||
| UpdateOutputNodeInfo(final_output_nodes, output_node_info); | |||
| } | |||
| } | |||
| } // namespace ge | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef PARSER_COMMON_PARSER_UTILS_H_ | |||
| #define PARSER_COMMON_PARSER_UTILS_H_ | |||
| #include <unordered_map> | |||
| #include "graph/graph.h" | |||
| #include "graph/node.h" | |||
| #include "external/ge/ge_api_error_codes.h" | |||
| @@ -24,15 +25,22 @@ | |||
| namespace ge { | |||
| class ParserUtils { | |||
| public: | |||
| static Status ExpandOneToManyGraph(Graph &graph); | |||
| using OutputNodeInfo = std::pair<std::string, int32_t>; | |||
| using OutputMapping = std::unordered_map<std::string, OutputNodeInfo>; | |||
| static Status ExpandOneToManyGraph(Graph &graph, OutputMapping &output_mapping); | |||
| static string GenOutputKey(const OutputNodeInfo &node_info); | |||
| static void UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info); | |||
| static void UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes); | |||
| private: | |||
| static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph); | |||
| static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph, | |||
| OutputMapping &output_mapping); | |||
| static Status HandleInputContext(const NodePtr &node, | |||
| const std::vector<NodePtr> &input_nodes, | |||
| const ComputeGraphPtr &compute_graph); | |||
| static Status HandleOutputContext(const NodePtr &node, | |||
| const std::vector<std::pair<NodePtr, int32_t>> &out_node_index); | |||
| const std::vector<std::pair<NodePtr, int32_t>> &out_node_index, | |||
| OutputMapping &output_mapping); | |||
| }; | |||
| } // namespace ge | |||
| #endif // PARSER_COMMON_PARSER_UTILS_H_ | |||
| @@ -360,7 +360,7 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||
| return SUCCESS; | |||
| } | |||
| Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) { | |||
| void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) { | |||
| int index = 0; | |||
| for (int i = 0; i < onnx_graph.node_size(); i++) { | |||
| ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||
| @@ -369,8 +369,6 @@ Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) { | |||
| node->set_name(node_name); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status OnnxModelParser::ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type) { | |||
| @@ -676,7 +674,8 @@ Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::ve | |||
| return SUCCESS; | |||
| } | |||
| Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops) { | |||
| Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops, | |||
| ParserUtils::OutputMapping &out_tensor_to_nodes) { | |||
| for (auto output_name : output_node_names_) { | |||
| auto itr = outputs_map_.find(output_name); | |||
| if (itr == outputs_map_.end()) { | |||
| @@ -696,6 +695,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec | |||
| } | |||
| int index = node_name_index.second; | |||
| output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)}); | |||
| out_tensor_to_nodes[output_name] = std::make_pair(node_name, index); | |||
| GELOGI("out node index %d, node:%s", index, node_name.c_str()); | |||
| } | |||
| } | |||
| @@ -870,7 +870,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||
| GE_RETURN_WITH_LOG_IF_ERROR(ProtoTypePassManager::Instance().Run(&onnx_graph, domi::ONNX), | |||
| "Run ProtoType Pass Failed"); | |||
| // 2. Get all inializer. | |||
| // 1. Get all inializer. | |||
| std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | |||
| for (int i = 0; i < onnx_graph.initializer_size(); i++) { | |||
| ge::onnx::TensorProto initializer_tensor = onnx_graph.initializer(i); | |||
| @@ -880,7 +880,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||
| } | |||
| } | |||
| // 3. Parse Input from graph. | |||
| // 2. Parse Input from graph. | |||
| GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); | |||
| Status ret = ParseInput(initializer_name_tensor, is_subgraph, onnx_graph); | |||
| @@ -890,13 +890,14 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||
| } | |||
| GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | |||
| // 4. Parse Constant from graph. | |||
| // 3. Parse Constant from graph. | |||
| ret = ParseInitializer(onnx_graph, initializer_name_tensor); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[Parse][Initializer] for onnx failed."); | |||
| return ret; | |||
| } | |||
| // 4. Get all output name form origin graph | |||
| ret = ParseOutput(onnx_graph); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[Parse][Output] Parse output for onnx failed."); | |||
| @@ -904,11 +905,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||
| } | |||
| // 5. Update node name for node do not has name. | |||
| ret = UpdateAllNodeName(onnx_graph); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[Update][Name] of all node for onnx failed."); | |||
| return ret; | |||
| } | |||
| UpdateAllNodeName(onnx_graph); | |||
| // 6 Precheck. | |||
| ret = Prechecker(onnx_graph); | |||
| @@ -950,19 +947,32 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||
| } | |||
| graph.SetInputs(input_ops); | |||
| // 10. Get output info and set outpus for subgraph | |||
| std::vector<std::pair<Operator, std::vector<size_t>>> output_ops; | |||
| ParserUtils::OutputMapping out_tensor_to_nodes; | |||
| ret = GetGraphOutputs(output_ops, out_tensor_to_nodes); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[Get][Outputs] failed."); | |||
| return ret; | |||
| } | |||
| // root graph needn't set outputs. | |||
| if(is_subgraph) { | |||
| std::vector<std::pair<Operator, std::vector<size_t>>> output_ops; | |||
| ret = GetGraphOutputs(output_ops); | |||
| graph.SetOutputs(output_ops); | |||
| } | |||
| // 11. Expand node to graph if need | |||
| ParserUtils::OutputMapping final_output_nodes; | |||
| GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph, final_output_nodes)); | |||
| // 12. Set outputs info in ParserContext for root graph | |||
| if (!is_subgraph) { | |||
| ret = SetOutputsInfo(final_output_nodes, out_tensor_to_nodes); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "[Get][Outputs] failed."); | |||
| GELOGE(ret, "[Set][OutputsInfo] Graph:%s.", graph.GetName().c_str()); | |||
| return ret; | |||
| } | |||
| graph.SetOutputs(output_ops); | |||
| } | |||
| GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); | |||
| GELOGI("Onnx model parser success."); | |||
| return SUCCESS; | |||
| } | |||
| @@ -1048,6 +1058,50 @@ void OnnxModelParser::UpdateDataFormat(ge::Graph &graph) { | |||
| return; | |||
| } | |||
| Status OnnxModelParser::SetOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes, | |||
| const ParserUtils::OutputMapping &tensor_to_nodes) { | |||
| auto &user_specified_nodes = ge::GetParserContext().user_out_nodes; | |||
| if (!user_specified_nodes.empty()) { | |||
| GELOGI("User specified the output nodes with node_name and index."); | |||
| for (auto &output_node_info : user_specified_nodes) { | |||
| ParserUtils::UpdateOutputNodeInfo(final_output_nodes, output_node_info); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| auto final_tensor_to_nodes = tensor_to_nodes; | |||
| ParserUtils::UpdateOutputCtx(final_output_nodes, final_tensor_to_nodes); | |||
| auto &user_specified_tensors = ge::GetParserContext().user_out_tensors; | |||
| auto &output_tensor_names = ge::GetParserContext().out_tensor_names; | |||
| output_tensor_names.clear(); | |||
| if (!user_specified_tensors.empty()) { | |||
| for (auto &tensor_name : user_specified_tensors) { | |||
| auto iter = final_tensor_to_nodes.find(tensor_name); | |||
| if (iter != final_tensor_to_nodes.end()) { | |||
| user_specified_nodes.emplace_back(iter->second); | |||
| output_tensor_names.emplace_back(tensor_name); | |||
| GELOGI("[UserSpecified]Add network output node[%s], index[%d], tensor name[%s].", | |||
| iter->second.first.c_str(), iter->second.second, tensor_name.c_str()); | |||
| } else { | |||
| REPORT_INNER_ERROR("E19999", "User specified tensor[%s] is not output of graph.", tensor_name.c_str()); | |||
| GELOGE(FAILED, "[Set][OutputsInfo]User specified tensor[%s] is not output of graph.", tensor_name.c_str()); | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| // for default output | |||
| auto &default_out_nodes = ge::GetParserContext().default_out_nodes; | |||
| for (auto &tensor_name : output_node_names_) { | |||
| auto &output_node_info = final_tensor_to_nodes[tensor_name]; | |||
| default_out_nodes.emplace_back(output_node_info); | |||
| output_tensor_names.emplace_back(tensor_name); | |||
| GELOGI("[Default]Add network output node[%s], index[%d], tensor name[%s].", | |||
| output_node_info.first.c_str(), output_node_info.second, tensor_name.c_str()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace domi | |||
| namespace domi { | |||
| @@ -38,6 +38,7 @@ | |||
| #include "omg/parser/model_parser.h" | |||
| #include "omg/parser/op_parser.h" | |||
| #include "omg/parser/weights_parser.h" | |||
| #include "common/parser_utils.h" | |||
| #include "proto/onnx/ge_onnx.pb.h" | |||
| namespace ge { | |||
| @@ -80,7 +81,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
| Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||
| std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor); | |||
| Status UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph); | |||
| void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph); | |||
| Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | |||
| @@ -94,7 +95,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
| Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops); | |||
| Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs); | |||
| Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs, | |||
| ParserUtils::OutputMapping &out_tensor_to_nodes); | |||
| Status Prechecker(ge::onnx::GraphProto &onnx_graph); | |||
| @@ -115,6 +117,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
| Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, | |||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | |||
| Status SetOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes, | |||
| const ParserUtils::OutputMapping &tensor_to_nodes); | |||
| std::map<std::string, std::string> ori_to_om_type_; | |||
| std::map<std::string, int64_t> domain_verseion_; | |||
| @@ -1494,7 +1494,9 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||
| GE_RETURN_IF_ERROR(AddEdges(graph)); | |||
| Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph); | |||
| GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph)); | |||
| ParserUtils::OutputMapping final_output_nodes; | |||
| GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes)); | |||
| GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes)); | |||
| GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); | |||
| GE_RETURN_IF_ERROR(CheckAndUpdateInputDesc(graph)); | |||
| GE_RETURN_IF_ERROR(graph->TopologicalSorting()); | |||
| @@ -2304,7 +2306,9 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||
| ret = AddEdges(graph); | |||
| Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph); | |||
| GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph)); | |||
| ParserUtils::OutputMapping final_output_nodes; | |||
| GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph, final_output_nodes)); | |||
| GE_RETURN_IF_ERROR(UpdateOutputsInfo(final_output_nodes)); | |||
| DeleteFuisonNodeDef(); | |||
| GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed"); | |||
| @@ -4020,6 +4024,16 @@ Status TensorFlowModelParser::CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compu | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes) { | |||
| auto &user_specified_nodes = ge::GetParserContext().user_out_nodes; | |||
| if (!user_specified_nodes.empty()) { | |||
| for (auto &output_node_info : user_specified_nodes) { | |||
| ParserUtils::UpdateOutputNodeInfo(final_output_nodes, output_node_info); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| namespace domi { | |||
| @@ -44,6 +44,7 @@ | |||
| #include "proto/tensorflow/graph_library.pb.h" | |||
| #include "external/register/scope/scope_fusion_pass_register.h" | |||
| #include "scope/scope_pass_manager.h" | |||
| #include "common/parser_utils.h" | |||
| using ge::ScopePassManager; | |||
| using domi::tensorflow::GraphDef; | |||
| @@ -647,6 +648,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
| Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser); | |||
| Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); | |||
| static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); | |||
| /** | |||
| * save <node_name, node_def> | |||
| @@ -275,10 +275,10 @@ INT32 mmGetPid() | |||
| } | |||
| INT32 mmDup2(INT32 oldFd, INT32 newFd) { | |||
| return -1; | |||
| return 0; | |||
| } | |||
| INT32 mmDup(INT32 fd) { | |||
| return -1; | |||
| return 0; | |||
| } | |||
| @@ -296,6 +296,7 @@ include_directories(${PARSER_DIR}) | |||
| include_directories(${PARSER_DIR}/inc) | |||
| include_directories(${PARSER_DIR}/parser) | |||
| include_directories(${PARSER_DIR}/parser/onnx) | |||
| include_directories(${PARSER_DIR}/tests) | |||
| include_directories(${PARSER_DIR}/metadef/inc) | |||
| include_directories(${PARSER_DIR}/metadef/inc/external) | |||
| include_directories(${PARSER_DIR}/metadef/inc/register) | |||
| @@ -306,7 +307,10 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||
| set(PARSER_UT_FILES | |||
| "parser_ut_utils.cc" | |||
| "testcase/common/acl_graph_parser_unittest.cc" | |||
| "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | |||
| "testcase/caffe_parser_testcase/caffe_parser_unittest.cc" | |||
| "testcase/onnx_parser_testcase/message2operator_unittest.cc" | |||
| "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | |||
| "testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc" | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ut/parser/parser_ut_utils.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| namespace ge { | |||
| void ParerUTestsUtils::ClearParserInnerCtx() { | |||
| ge::GetParserContext().input_nodes_format_map.clear(); | |||
| ge::GetParserContext().output_formats.clear(); | |||
| ge::GetParserContext().user_input_dims.clear(); | |||
| ge::GetParserContext().input_dims.clear(); | |||
| ge::GetParserContext().op_conf_map.clear(); | |||
| ge::GetParserContext().user_out_nodes.clear(); | |||
| ge::GetParserContext().default_out_nodes.clear(); | |||
| ge::GetParserContext().out_nodes_map.clear(); | |||
| ge::GetParserContext().user_out_tensors.clear(); | |||
| ge::GetParserContext().net_out_nodes.clear(); | |||
| ge::GetParserContext().out_tensor_names.clear(); | |||
| ge::GetParserContext().data_tensor_names.clear(); | |||
| ge::GetParserContext().is_dynamic_input = false; | |||
| ge::GetParserContext().train_flag = false; | |||
| ge::GetParserContext().format = domi::DOMI_TENSOR_ND; | |||
| ge::GetParserContext().type = domi::FRAMEWORK_RESERVED; | |||
| ge::GetParserContext().run_mode = GEN_OM_MODEL; | |||
| ge::GetParserContext().custom_proto_path = ""; | |||
| ge::GetParserContext().caffe_proto_path = ""; | |||
| ge::GetParserContext().enable_scope_fusion_passes = ""; | |||
| GELOGI("Clear parser inner context successfully."); | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_PARSER_TESTS_UT_PARSER_H_ | |||
| #define GE_PARSER_TESTS_UT_PARSER_H_ | |||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||
| namespace ge { | |||
| class ParerUTestsUtils { | |||
| public: | |||
| static void ClearParserInnerCtx(); | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_PARSER_TESTS_UT_PARSER_H_ | |||
| @@ -0,0 +1,14 @@ | |||
| name: "TestAbs" | |||
| layer { | |||
| name: "data" | |||
| type: "Input" | |||
| top: "data" | |||
| input_param { shape: { dim: 64 dim: 1 dim: 28 dim: 28 } } | |||
| } | |||
| layer { | |||
| name: "abs" | |||
| type: "AbsVal" | |||
| bottom: "data" | |||
| top: "abs_out" | |||
| } | |||
| @@ -0,0 +1,158 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include <iostream> | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "graph/operator_reg.h" | |||
| #include "external/graph/types.h" | |||
| #include "register/op_registry.h" | |||
| #include "parser/common/register_tbe.h" | |||
| #include "framework/omg/parser/model_parser.h" | |||
| #include "framework/omg/parser/parser_factory.h" | |||
| #include "external/parser/caffe_parser.h" | |||
| #include "ut/parser/parser_ut_utils.h" | |||
| #include "external/ge/ge_api_types.h" | |||
| namespace ge { | |||
| class UtestCaffeParser : public testing::Test { | |||
| protected: | |||
| void SetUp() { | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| RegisterCustomOp(); | |||
| } | |||
| void TearDown() {} | |||
| public: | |||
| void RegisterCustomOp(); | |||
| }; | |||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||
| return SUCCESS; | |||
| } | |||
| void UtestCaffeParser::RegisterCustomOp() { | |||
| REGISTER_CUSTOM_OP("Data") | |||
| .FrameworkType(domi::CAFFE) | |||
| .OriginOpType("Input") | |||
| .ParseParamsFn(ParseParams); | |||
| REGISTER_CUSTOM_OP("Abs") | |||
| .FrameworkType(domi::CAFFE) | |||
| .OriginOpType("AbsVal") | |||
| .ParseParamsFn(ParseParams); | |||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
| for (auto reg_data : reg_datas) { | |||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
| domi::OpRegistry::Instance()->Register(reg_data); | |||
| } | |||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
| } | |||
| namespace { | |||
| REG_OP(Data) | |||
| .INPUT(x, TensorType::ALL()) | |||
| .OUTPUT(y, TensorType::ALL()) | |||
| .ATTR(index, Int, 0) | |||
| .OP_END_FACTORY_REG(Data) | |||
| REG_OP(Abs) | |||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||
| .OP_END_FACTORY_REG(Abs) | |||
| } | |||
| TEST_F(UtestCaffeParser, caffe_parser_user_output_with_name_and_index) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt"; | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); | |||
| ASSERT_NE(model_parser, nullptr); | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||
| ASSERT_NE(compute_graph, nullptr); | |||
| ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| ge::GetParserContext().user_out_nodes.push_back({"abs", 0}); | |||
| auto ret = model_parser->Parse(model_file.c_str(), graph); | |||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| std::map<AscendString, AscendString> parser_params; | |||
| auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs"); | |||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||
| ASSERT_EQ(net_out_name.size(), 1); | |||
| EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); | |||
| } | |||
| TEST_F(UtestCaffeParser, caffe_parser_user_output_with_top_name) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt"; | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); | |||
| ASSERT_NE(model_parser, nullptr); | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||
| ASSERT_NE(compute_graph, nullptr); | |||
| ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| ge::GetParserContext().user_out_tensors.push_back("abs_out"); | |||
| auto ret = model_parser->Parse(model_file.c_str(), graph); | |||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| std::map<AscendString, AscendString> parser_params; | |||
| auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs"); | |||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||
| ASSERT_EQ(net_out_name.size(), 1); | |||
| EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); | |||
| } | |||
| TEST_F(UtestCaffeParser, caffe_parser_user_output_with_default) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/caffe_model/caffe_abs.pbtxt"; | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); | |||
| ASSERT_NE(model_parser, nullptr); | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||
| ASSERT_NE(compute_graph, nullptr); | |||
| ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||
| auto ret = model_parser->Parse(model_file.c_str(), graph); | |||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| std::map<AscendString, AscendString> parser_params; | |||
| auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs"); | |||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||
| ASSERT_EQ(net_out_name.size(), 1); | |||
| EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include <iostream> | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "graph/operator_reg.h" | |||
| #include "external/graph/types.h" | |||
| #include "register/op_registry.h" | |||
| #include "parser/common/register_tbe.h" | |||
| #include "external/parser/onnx_parser.h" | |||
| #include "ut/parser/parser_ut_utils.h" | |||
| #include "external/ge/ge_api_types.h" | |||
| namespace ge { | |||
| class UtestAclGraphParser : public testing::Test { | |||
| protected: | |||
| void SetUp() { | |||
| } | |||
| void TearDown() {} | |||
| }; | |||
| TEST_F(UtestAclGraphParser, test_parse_acl_output_nodes) { | |||
| AclGrphParseUtil acl_graph_parse_util; | |||
| string graph_name; | |||
| // case 1: Normal with 'node and index' | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<AscendString, AscendString> out_nodes_with_node_and_index = { | |||
| {AscendString(ge::ir_option::OUT_NODES), AscendString("Out1:0;Out2:1")}}; | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| auto ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_node_and_index, graph_name); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 2); | |||
| EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 2); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 0); | |||
| // case 2: Normal with 'tensor name' | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<AscendString, AscendString> out_nodes_with_tensor_name = { | |||
| {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out_tensor_2")}}; | |||
| ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_with_tensor_name, graph_name); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0); | |||
| EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 2); | |||
| // case 3: Failed with 'node and index' before 'tensor name' | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<AscendString, AscendString> out_nodes_mode_mixex_pre = { | |||
| {AscendString(ge::ir_option::OUT_NODES), AscendString("Out1:0;Out2:1;Out_tensor_1;Out_tensor_2")}}; | |||
| ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_pre, graph_name); | |||
| ASSERT_EQ(ret, PARAM_INVALID); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 2); | |||
| EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 2); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 0); | |||
| // case 4: Failed with 'node and index' inserted in 'tensor name' | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<AscendString, AscendString> out_nodes_mode_mixex_mid = { | |||
| {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out1:0;Out2:1;Out_tensor_2")}}; | |||
| ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_mid, graph_name); | |||
| ASSERT_EQ(ret, PARAM_INVALID); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0); | |||
| EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 1); | |||
| // case 5: Failed with 'node and index' after 'tensor name' | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| GetParserContext().type = domi::ONNX; | |||
| std::map<AscendString, AscendString> out_nodes_mode_mixex_post = { | |||
| {AscendString(ge::ir_option::OUT_NODES), AscendString("Out_tensor_1;Out_tensor_2;Out1:0;Out2:1")}}; | |||
| ret = acl_graph_parse_util.ParseParamsBeforeGraph(out_nodes_mode_mixex_post, graph_name); | |||
| ASSERT_EQ(ret, PARAM_INVALID); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_nodes.size(), 0); | |||
| EXPECT_EQ(ge::GetParserContext().out_nodes_map.size(), 0); | |||
| EXPECT_EQ(ge::GetParserContext().user_out_tensors.size(), 2); | |||
| } | |||
| } // namespace ge | |||
| @@ -22,12 +22,16 @@ | |||
| #include "register/op_registry.h" | |||
| #include "parser/common/register_tbe.h" | |||
| #include "external/parser/onnx_parser.h" | |||
| #include "ut/parser/parser_ut_utils.h" | |||
| #include "external/ge/ge_api_types.h" | |||
| namespace ge { | |||
| class UtestOnnxParser : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void SetUp() { | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| RegisterCustomOp(); | |||
| } | |||
| void TearDown() {} | |||
| @@ -152,28 +156,81 @@ REG_OP(Identity) | |||
| .OP_END_FACTORY_REG(Identity) | |||
| } | |||
| TEST_F(UtestOnnxParser, onnx_parser_success) { | |||
| RegisterCustomOp(); | |||
| TEST_F(UtestOnnxParser, onnx_parser_if_node) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/onnx_model/if.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||
| } | |||
| TEST_F(UtestOnnxParser, onnx_parser_user_output_with_name_and_index) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("Conv_0:0")}); | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| EXPECT_EQ(ret, domi::SUCCESS); | |||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0"); | |||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||
| ASSERT_EQ(net_out_name.size(), 1); | |||
| EXPECT_EQ(net_out_name.at(0), "Conv_0:0"); | |||
| } | |||
| TEST_F(UtestOnnxParser, onnx_parser_if_node) { | |||
| RegisterCustomOp(); | |||
| TEST_F(UtestOnnxParser, onnx_parser_user_output_with_tensor) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("y")}); | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0"); | |||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||
| ASSERT_EQ(net_out_name.size(), 1); | |||
| EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y"); | |||
| } | |||
| TEST_F(UtestOnnxParser, onnx_parser_user_output_with_default) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/onnx_model/if.onnx"; | |||
| std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0"); | |||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||
| ASSERT_EQ(net_out_name.size(), 1); | |||
| EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y"); | |||
| } | |||
| TEST_F(UtestOnnxParser, onnx_parser_user_output_with_tensor_failed) { | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| parser_params.insert({AscendString(ge::ir_option::OUT_NODES), AscendString("not_exist_output")}); | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| EXPECT_EQ(ret, domi::SUCCESS); | |||
| EXPECT_EQ(ret, FAILED); | |||
| } | |||
| } // namespace ge | |||
| @@ -24,12 +24,14 @@ | |||
| #include "register/op_registry.h" | |||
| #include "parser/common/register_tbe.h" | |||
| #include "external/parser/tensorflow_parser.h" | |||
| #include "ut/parser/parser_ut_utils.h" | |||
| namespace ge { | |||
| class UtestTensorflowParser : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void SetUp() { | |||
| ParerUTestsUtils::ClearParserInnerCtx(); | |||
| } | |||
| void TearDown() {} | |||