diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index f188a97..25edff9 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -54,6 +54,7 @@ #include "register/op_registry.h" #include "register/scope/scope_graph_impl.h" #include "register/scope/scope_pass_registry_impl.h" +#include "parser/common/auto_mapping_subgraph_io_index_func.h" using ge::const_op_update_vec; using ge::OpParserFactory; @@ -302,6 +303,80 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) { } return SUCCESS; } + +Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, ComputeGraphPtr &root_graph){ + // Inner function, input params have been checked by caller + Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(graph, + [](int in, int &out)->Status { + out = in; + return SUCCESS; + }, + [](int in, int &out)->Status { + out = in; + return SUCCESS; + }); + if (status != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]node:%s, sub graph name:%s.", + node->GetName().c_str(), graph.GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to map sub graph input and output, node:%s, sub graph name:%s.", + node->GetName().c_str(), graph.GetName().c_str()); + return INTERNAL_ERROR; + } + + ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + // Inner function, GetOpDesc has been checked by caller + (void)node->GetOpDesc()->AddSubgraphName("f"); + auto ret = NodeUtils::SetSubgraph(*node, 0, compute_graph); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Set][Subgraph]Node:%s, sub graph name:%s.", + node->GetName().c_str(), compute_graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to set sub graph, node: %s, sub graph name: %s.", + node->GetName().c_str(), compute_graph->GetName().c_str()); + return INTERNAL_ERROR; + } + for (const auto &sub_graph : compute_graph->GetAllSubgraphs()) { + ret = root_graph->AddSubGraph(sub_graph); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Add][Subgraph]Node:%s, sub graph name:%s, sub sub graph name:%s.", + node->GetName().c_str(), compute_graph->GetName().c_str(), sub_graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to add sub graph to root graph, node:%s, sub graph name:%s.", + node->GetName().c_str(), sub_graph->GetName().c_str()); + return INTERNAL_ERROR; + } + compute_graph->RemoveSubgraph(sub_graph); + GELOGD("Add subgraph[%s] to root graph[%s].", sub_graph->GetName().c_str(), root_graph->GetName().c_str()); + } + return SUCCESS; +} + +Status AddExternalGraph(ComputeGraphPtr &root_graph) { + GE_CHECK_NOTNULL(root_graph); + for (const NodePtr &node : root_graph->GetAllNodes()) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + continue; + } + std::string model_data; + if (AttrUtils::GetStr(node->GetOpDesc(), "_external_model", model_data) && !model_data.empty()) { + ge::Model model; + auto load_ret = ge::Model::Load(reinterpret_cast(model_data.data()), model_data.size(), model); + if (load_ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + Graph graph = model.GetGraph(); + GELOGD("Get subgraph[%s] from model[%s].", graph.GetName().c_str(), node->GetName().c_str()); + Status ret = MappingAndAddSubGraph(node, graph, root_graph); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]Node:%s.", node->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to map and add sub graph, node:%s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + } + return SUCCESS; +} } // namespace /** @@ -2397,6 +2472,11 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes return ret; } } + auto add_ret = AddExternalGraph(root_graph); + if (add_ret != SUCCESS) { + GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str()); + return add_ret; + } PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); return SUCCESS; } @@ -2460,6 +2540,11 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro return ret; } } + auto add_ret = AddExternalGraph(root_graph); + if (add_ret != SUCCESS) { + GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str()); + return add_ret; + } PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); return SUCCESS; }