diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index f188a97..640ab0a 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -54,6 +54,7 @@ #include "register/op_registry.h" #include "register/scope/scope_graph_impl.h" #include "register/scope/scope_pass_registry_impl.h" +#include "parser/common/auto_mapping_subgraph_io_index_func.h" using ge::const_op_update_vec; using ge::OpParserFactory; @@ -210,6 +211,7 @@ const std::vector kMakeOperatorNotByIr = {ge::parser::ARG, ge::pars const char *const kDpop = "DPOP"; const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt"; const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node"; +const char *const kExternalModel = "_external_model"; struct ParseArg { const google::protobuf::Message *proto; std::string function_name; @@ -302,6 +304,52 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) { } return SUCCESS; } + +Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, ComputeGraphPtr &root_graph){ + // Inner function, input params have been checked by caller + Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(graph, + [](int in, int &out)->Status { + out = in; + return SUCCESS; + }, + [](int in, int &out)->Status { + out = in; + return SUCCESS; + }); + if (status != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]node:%s, sub graph name:%s.", + node->GetName().c_str(), graph.GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to map sub graph input and output, node:%s, sub graph name:%s.", + node->GetName().c_str(), graph.GetName().c_str()); + return INTERNAL_ERROR; + } + + ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + // Inner function, GetOpDesc has been checked by caller + (void)node->GetOpDesc()->AddSubgraphName("f"); + auto ret = NodeUtils::SetSubgraph(*node, 0, compute_graph); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Set][Subgraph]Node:%s, sub graph name:%s.", + node->GetName().c_str(), compute_graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to set sub graph, node: %s, sub graph name: %s.", + node->GetName().c_str(), compute_graph->GetName().c_str()); + return INTERNAL_ERROR; + } + for (const auto &sub_graph : compute_graph->GetAllSubgraphs()) { + ret = root_graph->AddSubgraph(sub_graph); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Add][Subgraph]Node:%s, sub graph name:%s, sub sub graph name:%s.", + node->GetName().c_str(), compute_graph->GetName().c_str(), sub_graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to add sub graph to root graph, node:%s, sub graph name:%s.", + node->GetName().c_str(), sub_graph->GetName().c_str()); + return INTERNAL_ERROR; + } + compute_graph->RemoveSubgraph(sub_graph); + GELOGD("Add subgraph[%s] to root graph[%s].", sub_graph->GetName().c_str(), root_graph->GetName().c_str()); + } + return SUCCESS; +} } // namespace /** @@ -2397,6 +2445,11 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes return ret; } } + auto add_ret = AddExternalGraph(root_graph); + if (add_ret != SUCCESS) { + GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str()); + return add_ret; + } PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); return SUCCESS; } @@ -2460,6 +2513,11 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro return ret; } } + auto add_ret = AddExternalGraph(root_graph); + if (add_ret != SUCCESS) { + GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str()); + return add_ret; + } PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); return SUCCESS; } @@ -4034,6 +4092,34 @@ Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping } return SUCCESS; } + +Status TensorFlowModelParser::AddExternalGraph(ComputeGraphPtr &root_graph) { + GE_CHECK_NOTNULL(root_graph); + for (const NodePtr &node : root_graph->GetAllNodes()) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + continue; + } + std::string model_data; + if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) { + ge::Model model; + auto load_ret = ge::Model::Load(reinterpret_cast(model_data.data()), model_data.size(), model); + if (load_ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + Graph graph = model.GetGraph(); + GELOGD("Get subgraph[%s] from model[%s].", graph.GetName().c_str(), node->GetName().c_str()); + Status ret = MappingAndAddSubGraph(node, graph, root_graph); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]Node:%s.", node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to map and add sub graph, node:%s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + } + return SUCCESS; +} } // namespace ge namespace domi { diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index 94f5bb3..83bad49 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -649,6 +649,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr &op_parser); Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); + static Status AddExternalGraph(ComputeGraphPtr &root_graph); /** * save diff --git a/tests/ut/parser/parser_ut_utils.cc b/tests/ut/parser/parser_ut_utils.cc index 202ca1d..d0556d4 100644 --- a/tests/ut/parser/parser_ut_utils.cc +++ b/tests/ut/parser/parser_ut_utils.cc @@ -16,6 +16,7 @@ #include "ut/parser/parser_ut_utils.h" #include "framework/common/debug/ge_log.h" +#include "graph/utils/graph_utils.h" namespace ge { void ParerUTestsUtils::ClearParserInnerCtx() { @@ -41,4 +42,29 @@ void ParerUTestsUtils::ClearParserInnerCtx() { ge::GetParserContext().enable_scope_fusion_passes = ""; GELOGI("Clear parser inner context successfully."); } +namespace ut { +NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, + DataType data_type, std::vector shape) { + auto tensor_desc = std::make_shared(); + tensor_desc->SetShape(GeShape(std::move(shape))); + tensor_desc->SetFormat(format); + tensor_desc->SetDataType(data_type); + + auto op_desc = std::make_shared(name, type); + for (int i = 0; i < in_cnt; ++i) { + op_desc->AddInputDesc(tensor_desc->Clone()); + } + for (int i = 0; i < out_cnt; ++i) { + op_desc->AddOutputDesc(tensor_desc->Clone()); + } + + return graph_->AddNode(op_desc); +} +void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { + GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); +} +void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { + GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); +} +} // namespace ut } // namespace ge diff --git a/tests/ut/parser/parser_ut_utils.h b/tests/ut/parser/parser_ut_utils.h index 38596b6..516a760 100644 --- a/tests/ut/parser/parser_ut_utils.h +++ b/tests/ut/parser/parser_ut_utils.h @@ -18,12 +18,31 @@ #define GE_PARSER_TESTS_UT_PARSER_H_ #include "framework/omg/parser/parser_inner_ctx.h" +#include "graph/compute_graph.h" namespace ge { class ParerUTestsUtils { public: static void ClearParserInnerCtx(); }; +namespace ut { +class GraphBuilder { + public: + explicit GraphBuilder(const std::string &name) { graph_ = std::make_shared(name); } + NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, + Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, + std::vector shape = {1, 1, 224, 224}); + void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); + void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); + ComputeGraphPtr GetGraph() { + graph_->TopologicalSorting(); + return graph_; + } + + private: + ComputeGraphPtr graph_; +}; +} // namespace ut } // namespace ge #endif // GE_PARSER_TESTS_UT_PARSER_H_ diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index 86f2b8f..2339c2d 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -17,7 +17,11 @@ #include #include #include "parser/common/op_parser_factory.h" +#define private public +#define protected public #include "parser/tensorflow/tensorflow_parser.h" +#undef protected +#undef private #include "framework/omg/parser/parser_factory.h" #include "graph/operator_reg.h" #include "external/graph/types.h" @@ -25,6 +29,7 @@ #include "parser/common/register_tbe.h" #include "external/parser/tensorflow_parser.h" #include "ut/parser/parser_ut_utils.h" +#include "graph/model.h" namespace ge { class UtestTensorflowParser : public testing::Test { @@ -119,4 +124,95 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto3) { EXPECT_EQ(ret, ge::SUCCESS); } +TEST_F(UtestTensorflowParser, tensorflow_parser_with_external_graph) { + auto make_graph = [](const string &name) { + auto builder = ut::GraphBuilder(name); + auto data1 = builder.AddNode(name + "_input1", "Data", 1, 1); + auto data2 = builder.AddNode(name + "_input2", "Data", 1, 1); + auto add = builder.AddNode(name + "_add", "Add", 2, 1); + auto net_output = builder.AddNode(name + "_net_output", "NetOutput", 1, 1); + builder.AddDataEdge(data1, 0, add, 0); + builder.AddDataEdge(data2, 0, add, 1); + builder.AddDataEdge(add, 0, net_output, 0); + return builder.GetGraph(); + }; + // 1. Create root graph + ComputeGraphPtr root_graph = make_graph("root_graph"); + + // 2. Create ONNX sub graph + // 2.1 Sub graph of onnx graph + ge::ComputeGraphPtr sub_sub_graph = ge::parser::MakeShared("sub_sub"); + // 2.2 ONNX graph + ComputeGraphPtr sub_graph = make_graph("sub_graph"); + auto add = sub_graph->FindNode("sub_graph_add"); + ASSERT_NE(add, nullptr); + add->GetOpDesc()->AddSubgraphName("sub_sub_graph"); + add->GetOpDesc()->SetSubgraphInstanceName(0, sub_sub_graph->GetName()); + sub_graph->AddSubGraph(sub_sub_graph); + auto input1 = sub_graph->FindNode("sub_graph_input1"); + ASSERT_NE(input1, nullptr); + AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_INDEX, 0); + auto input2 = sub_graph->FindNode("sub_graph_input2"); + ASSERT_NE(input2, nullptr); + AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_INDEX, 1); + + // 3. Serialize ONNX graph to string + // 3.1 normal + ge::Model model("model", ""); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph)); + Buffer buffer; + graphStatus save_ret = model.Save(buffer, false); + ASSERT_EQ(save_ret, GRAPH_SUCCESS); + std::string external_graph(reinterpret_cast(buffer.GetData()), buffer.GetSize()); + // model will failed + input1->GetOpDesc()->DelAttr(ATTR_NAME_INDEX); + ge::Model model_will_fail("model_will_fail", ""); + model_will_fail.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph)); + Buffer buffer_fail; + save_ret = model_will_fail.Save(buffer_fail, false); + ASSERT_EQ(save_ret, GRAPH_SUCCESS); + std::string external_graph_fail(reinterpret_cast(buffer_fail.GetData()), buffer_fail.GetSize()); + + // 4. Set string to function node + auto root_add = root_graph->FindNode("root_graph_add"); + ASSERT_NE(root_add, nullptr); + AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph); + auto root_input1 = root_graph->FindNode("root_graph_input1"); + ASSERT_NE(root_input1, nullptr); + AttrUtils::SetInt(root_input1->GetOpDesc(), ATTR_NAME_INDEX, 0); + auto root_input2 = root_graph->FindNode("root_graph_input2"); + ASSERT_NE(root_input2, nullptr); + AttrUtils::SetInt(root_input2->GetOpDesc(), ATTR_NAME_INDEX, 1); + + // 5. Run test (normal) + auto ret = TensorFlowModelParser::AddExternalGraph(root_graph); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 2); + EXPECT_EQ(sub_graph->GetAllSubgraphs().size(), 1); + EXPECT_NE(root_graph->GetSubgraph(sub_graph->GetName()), nullptr); + EXPECT_EQ(root_graph->GetSubgraph(sub_graph->GetName())->GetAllSubgraphs().size(), 0); + + // 6. Run test (failed) + // 6.1 Failed to load model + AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", "dummy string"); + ret = TensorFlowModelParser::AddExternalGraph(root_graph); + EXPECT_EQ(ret, INTERNAL_ERROR); + + // 6.2 Failed to map sub graph + AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph_fail); + ret = TensorFlowModelParser::AddExternalGraph(root_graph); + EXPECT_EQ(ret, INTERNAL_ERROR); + + // 6.3 Failed to set sub graph to node + AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph); + root_add->SetOwnerComputeGraph(nullptr); + ret = TensorFlowModelParser::AddExternalGraph(root_graph); + EXPECT_EQ(ret, INTERNAL_ERROR); + + // 6.4 Failed to add sub sub graph + root_add->SetOwnerComputeGraph(nullptr); + root_graph->RemoveSubGraph(sub_graph); + ret = TensorFlowModelParser::AddExternalGraph(root_graph); + EXPECT_EQ(ret, INTERNAL_ERROR); +} } // namespace ge \ No newline at end of file