Merge pull request !369 from Fu Jingguo/masterpull/370/head
| @@ -54,6 +54,7 @@ | |||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "register/scope/scope_graph_impl.h" | #include "register/scope/scope_graph_impl.h" | ||||
| #include "register/scope/scope_pass_registry_impl.h" | #include "register/scope/scope_pass_registry_impl.h" | ||||
| #include "parser/common/auto_mapping_subgraph_io_index_func.h" | |||||
| using ge::const_op_update_vec; | using ge::const_op_update_vec; | ||||
| using ge::OpParserFactory; | using ge::OpParserFactory; | ||||
| @@ -210,6 +211,7 @@ const std::vector<std::string> kMakeOperatorNotByIr = {ge::parser::ARG, ge::pars | |||||
| const char *const kDpop = "DPOP"; | const char *const kDpop = "DPOP"; | ||||
| const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt"; | const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt"; | ||||
| const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node"; | const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node"; | ||||
| const char *const kExternalModel = "_external_model"; | |||||
| struct ParseArg { | struct ParseArg { | ||||
| const google::protobuf::Message *proto; | const google::protobuf::Message *proto; | ||||
| std::string function_name; | std::string function_name; | ||||
| @@ -302,6 +304,52 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) { | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, ComputeGraphPtr &root_graph){ | |||||
| // Inner function, input params have been checked by caller | |||||
| Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(graph, | |||||
| [](int in, int &out)->Status { | |||||
| out = in; | |||||
| return SUCCESS; | |||||
| }, | |||||
| [](int in, int &out)->Status { | |||||
| out = in; | |||||
| return SUCCESS; | |||||
| }); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]node:%s, sub graph name:%s.", | |||||
| node->GetName().c_str(), graph.GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to map sub graph input and output, node:%s, sub graph name:%s.", | |||||
| node->GetName().c_str(), graph.GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| // Inner function, GetOpDesc has been checked by caller | |||||
| (void)node->GetOpDesc()->AddSubgraphName("f"); | |||||
| auto ret = NodeUtils::SetSubgraph(*node, 0, compute_graph); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Set][Subgraph]Node:%s, sub graph name:%s.", | |||||
| node->GetName().c_str(), compute_graph->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to set sub graph, node: %s, sub graph name: %s.", | |||||
| node->GetName().c_str(), compute_graph->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| for (const auto &sub_graph : compute_graph->GetAllSubgraphs()) { | |||||
| ret = root_graph->AddSubgraph(sub_graph); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Add][Subgraph]Node:%s, sub graph name:%s, sub sub graph name:%s.", | |||||
| node->GetName().c_str(), compute_graph->GetName().c_str(), sub_graph->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to add sub graph to root graph, node:%s, sub graph name:%s.", | |||||
| node->GetName().c_str(), sub_graph->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| compute_graph->RemoveSubgraph(sub_graph); | |||||
| GELOGD("Add subgraph[%s] to root graph[%s].", sub_graph->GetName().c_str(), root_graph->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| /** | /** | ||||
| @@ -2397,6 +2445,11 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| } | } | ||||
| auto add_ret = AddExternalGraph(root_graph); | |||||
| if (add_ret != SUCCESS) { | |||||
| GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str()); | |||||
| return add_ret; | |||||
| } | |||||
| PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); | PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -2460,6 +2513,11 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| } | } | ||||
| auto add_ret = AddExternalGraph(root_graph); | |||||
| if (add_ret != SUCCESS) { | |||||
| GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str()); | |||||
| return add_ret; | |||||
| } | |||||
| PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); | PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -4034,6 +4092,34 @@ Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::AddExternalGraph(ComputeGraphPtr &root_graph) { | |||||
| GE_CHECK_NOTNULL(root_graph); | |||||
| for (const NodePtr &node : root_graph->GetAllNodes()) { | |||||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| std::string model_data; | |||||
| if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) { | |||||
| ge::Model model; | |||||
| auto load_ret = ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.data()), model_data.size(), model); | |||||
| if (load_ret != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| Graph graph = model.GetGraph(); | |||||
| GELOGD("Get subgraph[%s] from model[%s].", graph.GetName().c_str(), node->GetName().c_str()); | |||||
| Status ret = MappingAndAddSubGraph(node, graph, root_graph); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]Node:%s.", node->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to map and add sub graph, node:%s.", node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| namespace domi { | namespace domi { | ||||
| @@ -649,6 +649,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser); | Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser); | ||||
| Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); | Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); | ||||
| static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); | static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); | ||||
| static Status AddExternalGraph(ComputeGraphPtr &root_graph); | |||||
| /** | /** | ||||
| * save <node_name, node_def> | * save <node_name, node_def> | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "ut/parser/parser_ut_utils.h" | #include "ut/parser/parser_ut_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/utils/graph_utils.h" | |||||
| namespace ge { | namespace ge { | ||||
| void ParerUTestsUtils::ClearParserInnerCtx() { | void ParerUTestsUtils::ClearParserInnerCtx() { | ||||
| @@ -41,4 +42,29 @@ void ParerUTestsUtils::ClearParserInnerCtx() { | |||||
| ge::GetParserContext().enable_scope_fusion_passes = ""; | ge::GetParserContext().enable_scope_fusion_passes = ""; | ||||
| GELOGI("Clear parser inner context successfully."); | GELOGI("Clear parser inner context successfully."); | ||||
| } | } | ||||
| namespace ut { | |||||
| NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, | |||||
| DataType data_type, std::vector<int64_t> shape) { | |||||
| auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||||
| tensor_desc->SetShape(GeShape(std::move(shape))); | |||||
| tensor_desc->SetFormat(format); | |||||
| tensor_desc->SetDataType(data_type); | |||||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (int i = 0; i < in_cnt; ++i) { | |||||
| op_desc->AddInputDesc(tensor_desc->Clone()); | |||||
| } | |||||
| for (int i = 0; i < out_cnt; ++i) { | |||||
| op_desc->AddOutputDesc(tensor_desc->Clone()); | |||||
| } | |||||
| return graph_->AddNode(op_desc); | |||||
| } | |||||
| void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { | |||||
| GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); | |||||
| } | |||||
| void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { | |||||
| GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); | |||||
| } | |||||
| } // namespace ut | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -18,12 +18,31 @@ | |||||
| #define GE_PARSER_TESTS_UT_PARSER_H_ | #define GE_PARSER_TESTS_UT_PARSER_H_ | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
| #include "graph/compute_graph.h" | |||||
| namespace ge { | namespace ge { | ||||
| class ParerUTestsUtils { | class ParerUTestsUtils { | ||||
| public: | public: | ||||
| static void ClearParserInnerCtx(); | static void ClearParserInnerCtx(); | ||||
| }; | }; | ||||
| namespace ut { | |||||
| class GraphBuilder { | |||||
| public: | |||||
| explicit GraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); } | |||||
| NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
| Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||||
| std::vector<int64_t> shape = {1, 1, 224, 224}); | |||||
| void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); | |||||
| void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); | |||||
| ComputeGraphPtr GetGraph() { | |||||
| graph_->TopologicalSorting(); | |||||
| return graph_; | |||||
| } | |||||
| private: | |||||
| ComputeGraphPtr graph_; | |||||
| }; | |||||
| } // namespace ut | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_PARSER_TESTS_UT_PARSER_H_ | #endif // GE_PARSER_TESTS_UT_PARSER_H_ | ||||
| @@ -17,7 +17,11 @@ | |||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #define private public | |||||
| #define protected public | |||||
| #include "parser/tensorflow/tensorflow_parser.h" | #include "parser/tensorflow/tensorflow_parser.h" | ||||
| #undef protected | |||||
| #undef private | |||||
| #include "framework/omg/parser/parser_factory.h" | #include "framework/omg/parser/parser_factory.h" | ||||
| #include "graph/operator_reg.h" | #include "graph/operator_reg.h" | ||||
| #include "external/graph/types.h" | #include "external/graph/types.h" | ||||
| @@ -25,6 +29,7 @@ | |||||
| #include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
| #include "external/parser/tensorflow_parser.h" | #include "external/parser/tensorflow_parser.h" | ||||
| #include "ut/parser/parser_ut_utils.h" | #include "ut/parser/parser_ut_utils.h" | ||||
| #include "graph/model.h" | |||||
| namespace ge { | namespace ge { | ||||
| class UtestTensorflowParser : public testing::Test { | class UtestTensorflowParser : public testing::Test { | ||||
| @@ -119,4 +124,95 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto3) { | |||||
| EXPECT_EQ(ret, ge::SUCCESS); | EXPECT_EQ(ret, ge::SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestTensorflowParser, tensorflow_parser_with_external_graph) { | |||||
| auto make_graph = [](const string &name) { | |||||
| auto builder = ut::GraphBuilder(name); | |||||
| auto data1 = builder.AddNode(name + "_input1", "Data", 1, 1); | |||||
| auto data2 = builder.AddNode(name + "_input2", "Data", 1, 1); | |||||
| auto add = builder.AddNode(name + "_add", "Add", 2, 1); | |||||
| auto net_output = builder.AddNode(name + "_net_output", "NetOutput", 1, 1); | |||||
| builder.AddDataEdge(data1, 0, add, 0); | |||||
| builder.AddDataEdge(data2, 0, add, 1); | |||||
| builder.AddDataEdge(add, 0, net_output, 0); | |||||
| return builder.GetGraph(); | |||||
| }; | |||||
| // 1. Create root graph | |||||
| ComputeGraphPtr root_graph = make_graph("root_graph"); | |||||
| // 2. Create ONNX sub graph | |||||
| // 2.1 Sub graph of onnx graph | |||||
| ge::ComputeGraphPtr sub_sub_graph = ge::parser::MakeShared<ge::ComputeGraph>("sub_sub"); | |||||
| // 2.2 ONNX graph | |||||
| ComputeGraphPtr sub_graph = make_graph("sub_graph"); | |||||
| auto add = sub_graph->FindNode("sub_graph_add"); | |||||
| ASSERT_NE(add, nullptr); | |||||
| add->GetOpDesc()->AddSubgraphName("sub_sub_graph"); | |||||
| add->GetOpDesc()->SetSubgraphInstanceName(0, sub_sub_graph->GetName()); | |||||
| sub_graph->AddSubGraph(sub_sub_graph); | |||||
| auto input1 = sub_graph->FindNode("sub_graph_input1"); | |||||
| ASSERT_NE(input1, nullptr); | |||||
| AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| auto input2 = sub_graph->FindNode("sub_graph_input2"); | |||||
| ASSERT_NE(input2, nullptr); | |||||
| AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| // 3. Serialize ONNX graph to string | |||||
| // 3.1 normal | |||||
| ge::Model model("model", ""); | |||||
| model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph)); | |||||
| Buffer buffer; | |||||
| graphStatus save_ret = model.Save(buffer, false); | |||||
| ASSERT_EQ(save_ret, GRAPH_SUCCESS); | |||||
| std::string external_graph(reinterpret_cast<const char *>(buffer.GetData()), buffer.GetSize()); | |||||
| // model will failed | |||||
| input1->GetOpDesc()->DelAttr(ATTR_NAME_INDEX); | |||||
| ge::Model model_will_fail("model_will_fail", ""); | |||||
| model_will_fail.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph)); | |||||
| Buffer buffer_fail; | |||||
| save_ret = model_will_fail.Save(buffer_fail, false); | |||||
| ASSERT_EQ(save_ret, GRAPH_SUCCESS); | |||||
| std::string external_graph_fail(reinterpret_cast<const char *>(buffer_fail.GetData()), buffer_fail.GetSize()); | |||||
| // 4. Set string to function node | |||||
| auto root_add = root_graph->FindNode("root_graph_add"); | |||||
| ASSERT_NE(root_add, nullptr); | |||||
| AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph); | |||||
| auto root_input1 = root_graph->FindNode("root_graph_input1"); | |||||
| ASSERT_NE(root_input1, nullptr); | |||||
| AttrUtils::SetInt(root_input1->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| auto root_input2 = root_graph->FindNode("root_graph_input2"); | |||||
| ASSERT_NE(root_input2, nullptr); | |||||
| AttrUtils::SetInt(root_input2->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| // 5. Run test (normal) | |||||
| auto ret = TensorFlowModelParser::AddExternalGraph(root_graph); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 2); | |||||
| EXPECT_EQ(sub_graph->GetAllSubgraphs().size(), 1); | |||||
| EXPECT_NE(root_graph->GetSubgraph(sub_graph->GetName()), nullptr); | |||||
| EXPECT_EQ(root_graph->GetSubgraph(sub_graph->GetName())->GetAllSubgraphs().size(), 0); | |||||
| // 6. Run test (failed) | |||||
| // 6.1 Failed to load model | |||||
| AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", "dummy string"); | |||||
| ret = TensorFlowModelParser::AddExternalGraph(root_graph); | |||||
| EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
| // 6.2 Failed to map sub graph | |||||
| AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph_fail); | |||||
| ret = TensorFlowModelParser::AddExternalGraph(root_graph); | |||||
| EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
| // 6.3 Failed to set sub graph to node | |||||
| AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph); | |||||
| root_add->SetOwnerComputeGraph(nullptr); | |||||
| ret = TensorFlowModelParser::AddExternalGraph(root_graph); | |||||
| EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
| // 6.4 Failed to add sub sub graph | |||||
| root_add->SetOwnerComputeGraph(nullptr); | |||||
| root_graph->RemoveSubGraph(sub_graph); | |||||
| ret = TensorFlowModelParser::AddExternalGraph(root_graph); | |||||
| EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||