From 18fa76d80ef4b52b8a156e6f85fe4293dfa9aefe Mon Sep 17 00:00:00 2001 From: medivh-x Date: Mon, 31 May 2021 10:30:13 +0800 Subject: [PATCH] model parser now support serialized proto input --- parser/tensorflow/tensorflow_parser.cc | 63 +++++++++++++++++++ parser/tensorflow/tensorflow_parser.h | 21 +++++++ .../tensorflow_parser_unittest.cc | 31 +++++++++ 3 files changed, 115 insertions(+) diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 74ce961..9352885 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -2392,6 +2392,69 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes return SUCCESS; } +Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) { + if (serialized_proto.empty()) { + GELOGE(FAILED, "Deserialize proto failed as serialized proto is empty"); + return FAILED; + } + domi::tensorflow::GraphDef graph_def; + if (!graph_def.ParseFromString(serialized_proto)) { + GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed"); + return FAILED; + } + return ParseProto(reinterpret_cast(&graph_def), graph); +} + +Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, + domi::GetGraphCallbackV2 callback, + ge::ComputeGraphPtr &root_graph) { + ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); + ErrorManager::GetInstance().GenWorkStreamIdDefault(); + GE_CHECK_NOTNULL(callback); + GE_CHECK_NOTNULL(root_graph); + + PARSER_TIMESTAMP_START(ParseProtoWithSubgraph); + std::deque tasks; + tasks.push_back({nullptr, "root", nullptr, "", root_graph}); + bool root_parsed = false; + + while (!tasks.empty()) { + auto arg = tasks.front(); + tasks.pop_front(); + + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW); + + Status ret = SUCCESS; + if (root_parsed) { + GELOGI("Begin to parse serialized proto of sub graph %s", arg.function_name.c_str()); + ret = model_parser->ParseProto(callback(arg.function_name), arg.graph); + } else { + GELOGI("Begin to parse serialized proto of root graph"); + ret = model_parser->ParseProto(root_proto, arg.graph); + root_parsed = true; + } + + if (ret != SUCCESS) { + GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(), + arg.graph->GetName().c_str()); + return ret; + } + + ret = PostOpProcessForSubgraph(arg); + if (ret != SUCCESS) { + return ret; // the error log has been printed inner the function + } + + ret = GenSubgraphParseTasks(arg.graph, tasks); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to gen tasks for sub graph of graph %s", arg.graph->GetName().c_str()); + return ret; + } + } + PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); + return SUCCESS; +} + // For the identity operator whose output is "_retval", optimize it. Status TensorFlowModelParser::OptimizeIdentityByOutput(map &nodedef_map, const string &curr_node_name, bool &clear_input_flag) { diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index 159e338..5ecf9e6 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -138,6 +138,27 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override ; + /** + * @ingroup domi_omg + * @brief Analyze network model data + * @param [in] proto serialized network model + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + Status ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) override; + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] proto serialized network model + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + Status ParseProtoWithSubgraph(const std::string &serialized_proto, domi::GetGraphCallbackV2 callback, + ge::ComputeGraphPtr &graph) override; private: Status Parse(const char *file, ge::ComputeGraphPtr &graph); diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index 9a7d910..f597289 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -17,6 +17,8 @@ #include #include #include "parser/common/op_parser_factory.h" +#include "parser/tensorflow/tensorflow_parser.h" +#include "framework/omg/parser/parser_factory.h" #include "graph/operator_reg.h" #include "external/graph/types.h" #include "register/op_registry.h" @@ -85,5 +87,34 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) { EXPECT_EQ(ret, domi::SUCCESS); } +TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto1) { + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); + ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(std::string(""), + [](std::string)->std::string{ return "";}, compute_graph); + EXPECT_NE(ret, ge::SUCCESS); +} + +TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto2) { + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); + ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(std::string("null"), + [](std::string)->std::string{ return "";}, compute_graph); + EXPECT_NE(ret, ge::SUCCESS); +} + +TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto3) { + ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared("tmpGraph"); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); + + domi::tensorflow::GraphDef graph_def; + auto arg_node = graph_def.add_node(); + arg_node->set_name("noop"); + arg_node->set_op("NoOp"); + + ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(graph_def.SerializeAsString(), + [](std::string)->std::string{ return "";}, compute_graph); + EXPECT_EQ(ret, ge::SUCCESS); +} } // namespace ge \ No newline at end of file