| @@ -2392,6 +2392,69 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes | |||
| return SUCCESS; | |||
| } | |||
| Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) { | |||
| if (serialized_proto.empty()) { | |||
| GELOGE(FAILED, "Deserialize proto failed as serialized proto is empty"); | |||
| return FAILED; | |||
| } | |||
| domi::tensorflow::GraphDef graph_def; | |||
| if (!graph_def.ParseFromString(serialized_proto)) { | |||
| GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed"); | |||
| return FAILED; | |||
| } | |||
| return ParseProto(reinterpret_cast<const google::protobuf::Message *>(&graph_def), graph); | |||
| } | |||
| Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, | |||
| domi::GetGraphCallbackV2 callback, | |||
| ge::ComputeGraphPtr &root_graph) { | |||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); | |||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||
| GE_CHECK_NOTNULL(callback); | |||
| GE_CHECK_NOTNULL(root_graph); | |||
| PARSER_TIMESTAMP_START(ParseProtoWithSubgraph); | |||
| std::deque<ParseArg> tasks; | |||
| tasks.push_back({nullptr, "root", nullptr, "", root_graph}); | |||
| bool root_parsed = false; | |||
| while (!tasks.empty()) { | |||
| auto arg = tasks.front(); | |||
| tasks.pop_front(); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW); | |||
| Status ret = SUCCESS; | |||
| if (root_parsed) { | |||
| GELOGI("Begin to parse serialized proto of sub graph %s", arg.function_name.c_str()); | |||
| ret = model_parser->ParseProto(callback(arg.function_name), arg.graph); | |||
| } else { | |||
| GELOGI("Begin to parse serialized proto of root graph"); | |||
| ret = model_parser->ParseProto(root_proto, arg.graph); | |||
| root_parsed = true; | |||
| } | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(), | |||
| arg.graph->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| ret = PostOpProcessForSubgraph(arg); | |||
| if (ret != SUCCESS) { | |||
| return ret; // the error log has been printed inner the function | |||
| } | |||
| ret = GenSubgraphParseTasks(arg.graph, tasks); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to gen tasks for sub graph of graph %s", arg.graph->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| } | |||
| PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); | |||
| return SUCCESS; | |||
| } | |||
| // For the identity operator whose output is "_retval", optimize it. | |||
| Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, | |||
| const string &curr_node_name, bool &clear_input_flag) { | |||
| @@ -138,6 +138,27 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||
| Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override ; | |||
| /** | |||
| * @ingroup domi_omg | |||
| * @brief Analyze network model data | |||
| * @param [in] proto serialized network model | |||
| * @param [in|out] graph Save the network information after analysis | |||
| * @return SUCCESS | |||
| * @return Others failed | |||
| */ | |||
| Status ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) override; | |||
| /** | |||
| * @ingroup domi_omg | |||
| * @brief Analyze callback model data in subgraph | |||
| * @param [in] proto serialized network model | |||
| * @param [in] callback callback of subgraph | |||
| * @param [in|out] graph Save the network information after analysis | |||
| * @return SUCCESS | |||
| * @return Others failed | |||
| */ | |||
| Status ParseProtoWithSubgraph(const std::string &serialized_proto, domi::GetGraphCallbackV2 callback, | |||
| ge::ComputeGraphPtr &graph) override; | |||
| private: | |||
| Status Parse(const char *file, ge::ComputeGraphPtr &graph); | |||
| @@ -17,6 +17,8 @@ | |||
| #include <gtest/gtest.h> | |||
| #include <iostream> | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "parser/tensorflow/tensorflow_parser.h" | |||
| #include "framework/omg/parser/parser_factory.h" | |||
| #include "graph/operator_reg.h" | |||
| #include "external/graph/types.h" | |||
| #include "register/op_registry.h" | |||
| @@ -85,5 +87,34 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||
| EXPECT_EQ(ret, domi::SUCCESS); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto1) { | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); | |||
| ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(std::string(""), | |||
| [](std::string)->std::string{ return "";}, compute_graph); | |||
| EXPECT_NE(ret, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto2) { | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); | |||
| ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(std::string("null"), | |||
| [](std::string)->std::string{ return "";}, compute_graph); | |||
| EXPECT_NE(ret, ge::SUCCESS); | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto3) { | |||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); | |||
| domi::tensorflow::GraphDef graph_def; | |||
| auto arg_node = graph_def.add_node(); | |||
| arg_node->set_name("noop"); | |||
| arg_node->set_op("NoOp"); | |||
| ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(graph_def.SerializeAsString(), | |||
| [](std::string)->std::string{ return "";}, compute_graph); | |||
| EXPECT_EQ(ret, ge::SUCCESS); | |||
| } | |||
| } // namespace ge | |||