| @@ -2392,6 +2392,69 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) { | |||||
| if (serialized_proto.empty()) { | |||||
| GELOGE(FAILED, "Deserialize proto failed as serialized proto is empty"); | |||||
| return FAILED; | |||||
| } | |||||
| domi::tensorflow::GraphDef graph_def; | |||||
| if (!graph_def.ParseFromString(serialized_proto)) { | |||||
| GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed"); | |||||
| return FAILED; | |||||
| } | |||||
| return ParseProto(reinterpret_cast<const google::protobuf::Message *>(&graph_def), graph); | |||||
| } | |||||
| Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, | |||||
| domi::GetGraphCallbackV2 callback, | |||||
| ge::ComputeGraphPtr &root_graph) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); | |||||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | |||||
| GE_CHECK_NOTNULL(callback); | |||||
| GE_CHECK_NOTNULL(root_graph); | |||||
| PARSER_TIMESTAMP_START(ParseProtoWithSubgraph); | |||||
| std::deque<ParseArg> tasks; | |||||
| tasks.push_back({nullptr, "root", nullptr, "", root_graph}); | |||||
| bool root_parsed = false; | |||||
| while (!tasks.empty()) { | |||||
| auto arg = tasks.front(); | |||||
| tasks.pop_front(); | |||||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW); | |||||
| Status ret = SUCCESS; | |||||
| if (root_parsed) { | |||||
| GELOGI("Begin to parse serialized proto of sub graph %s", arg.function_name.c_str()); | |||||
| ret = model_parser->ParseProto(callback(arg.function_name), arg.graph); | |||||
| } else { | |||||
| GELOGI("Begin to parse serialized proto of root graph"); | |||||
| ret = model_parser->ParseProto(root_proto, arg.graph); | |||||
| root_parsed = true; | |||||
| } | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(), | |||||
| arg.graph->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| ret = PostOpProcessForSubgraph(arg); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; // the error log has been printed inner the function | |||||
| } | |||||
| ret = GenSubgraphParseTasks(arg.graph, tasks); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to gen tasks for sub graph of graph %s", arg.graph->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); | |||||
| return SUCCESS; | |||||
| } | |||||
| // For the identity operator whose output is "_retval", optimize it. | // For the identity operator whose output is "_retval", optimize it. | ||||
| Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, | Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, | ||||
| const string &curr_node_name, bool &clear_input_flag) { | const string &curr_node_name, bool &clear_input_flag) { | ||||
| @@ -138,6 +138,27 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override ; | Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override ; | ||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Analyze network model data | |||||
| * @param [in] proto serialized network model | |||||
| * @param [in|out] graph Save the network information after analysis | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| Status ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) override; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Analyze callback model data in subgraph | |||||
| * @param [in] proto serialized network model | |||||
| * @param [in] callback callback of subgraph | |||||
| * @param [in|out] graph Save the network information after analysis | |||||
| * @return SUCCESS | |||||
| * @return Others failed | |||||
| */ | |||||
| Status ParseProtoWithSubgraph(const std::string &serialized_proto, domi::GetGraphCallbackV2 callback, | |||||
| ge::ComputeGraphPtr &graph) override; | |||||
| private: | private: | ||||
| Status Parse(const char *file, ge::ComputeGraphPtr &graph); | Status Parse(const char *file, ge::ComputeGraphPtr &graph); | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "parser/tensorflow/tensorflow_parser.h" | |||||
| #include "framework/omg/parser/parser_factory.h" | |||||
| #include "graph/operator_reg.h" | #include "graph/operator_reg.h" | ||||
| #include "external/graph/types.h" | #include "external/graph/types.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| @@ -85,5 +87,34 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||||
| EXPECT_EQ(ret, domi::SUCCESS); | EXPECT_EQ(ret, domi::SUCCESS); | ||||
| } | } | ||||
| TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto1) { | |||||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); | |||||
| ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(std::string(""), | |||||
| [](std::string)->std::string{ return "";}, compute_graph); | |||||
| EXPECT_NE(ret, ge::SUCCESS); | |||||
| } | |||||
| TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto2) { | |||||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); | |||||
| ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(std::string("null"), | |||||
| [](std::string)->std::string{ return "";}, compute_graph); | |||||
| EXPECT_NE(ret, ge::SUCCESS); | |||||
| } | |||||
| TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto3) { | |||||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); | |||||
| domi::tensorflow::GraphDef graph_def; | |||||
| auto arg_node = graph_def.add_node(); | |||||
| arg_node->set_name("noop"); | |||||
| arg_node->set_op("NoOp"); | |||||
| ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(graph_def.SerializeAsString(), | |||||
| [](std::string)->std::string{ return "";}, compute_graph); | |||||
| EXPECT_EQ(ret, ge::SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||