Browse Source

!308 Model parser 支持序列化字符串输入

Merge pull request !308 from 薛鹏/master
pull/308/MERGE
i-robot Gitee 4 years ago
parent
commit
2597604991
3 changed files with 115 additions and 0 deletions
  1. +63
    -0
      parser/tensorflow/tensorflow_parser.cc
  2. +21
    -0
      parser/tensorflow/tensorflow_parser.h
  3. +31
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 63
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -2392,6 +2392,69 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes
return SUCCESS;
}

Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) {
if (serialized_proto.empty()) {
GELOGE(FAILED, "Deserialize proto failed as serialized proto is empty");
return FAILED;
}
domi::tensorflow::GraphDef graph_def;
if (!graph_def.ParseFromString(serialized_proto)) {
GELOGE(FAILED, "Proto object GraphDef parse serialized proto failed");
return FAILED;
}
return ParseProto(reinterpret_cast<const google::protobuf::Message *>(&graph_def), graph);
}

Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto,
domi::GetGraphCallbackV2 callback,
ge::ComputeGraphPtr &root_graph) {
ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser);
ErrorManager::GetInstance().GenWorkStreamIdDefault();
GE_CHECK_NOTNULL(callback);
GE_CHECK_NOTNULL(root_graph);

PARSER_TIMESTAMP_START(ParseProtoWithSubgraph);
std::deque<ParseArg> tasks;
tasks.push_back({nullptr, "root", nullptr, "", root_graph});
bool root_parsed = false;

while (!tasks.empty()) {
auto arg = tasks.front();
tasks.pop_front();

auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW);

Status ret = SUCCESS;
if (root_parsed) {
GELOGI("Begin to parse serialized proto of sub graph %s", arg.function_name.c_str());
ret = model_parser->ParseProto(callback(arg.function_name), arg.graph);
} else {
GELOGI("Begin to parse serialized proto of root graph");
ret = model_parser->ParseProto(root_proto, arg.graph);
root_parsed = true;
}

if (ret != SUCCESS) {
GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(),
arg.graph->GetName().c_str());
return ret;
}

ret = PostOpProcessForSubgraph(arg);
if (ret != SUCCESS) {
return ret; // the error log has been printed inner the function
}

ret = GenSubgraphParseTasks(arg.graph, tasks);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to gen tasks for sub graph of graph %s", arg.graph->GetName().c_str());
return ret;
}
}
PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph");
return SUCCESS;
}

// For the identity operator whose output is "_retval", optimize it.
Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map,
const string &curr_node_name, bool &clear_input_flag) {


+ 21
- 0
parser/tensorflow/tensorflow_parser.h View File

@@ -138,6 +138,27 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {

Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override ;

/**
* @ingroup domi_omg
* @brief Analyze network model data
* @param [in] proto serialized network model
* @param [in|out] graph Save the network information after analysis
* @return SUCCESS
* @return Others failed
*/
Status ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) override;

/**
* @ingroup domi_omg
* @brief Analyze callback model data in subgraph
* @param [in] proto serialized network model
* @param [in] callback callback of subgraph
* @param [in|out] graph Save the network information after analysis
* @return SUCCESS
* @return Others failed
*/
Status ParseProtoWithSubgraph(const std::string &serialized_proto, domi::GetGraphCallbackV2 callback,
ge::ComputeGraphPtr &graph) override;
private:
Status Parse(const char *file, ge::ComputeGraphPtr &graph);



+ 31
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -17,6 +17,8 @@
#include <gtest/gtest.h>
#include <iostream>
#include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_parser.h"
#include "framework/omg/parser/parser_factory.h"
#include "graph/operator_reg.h"
#include "external/graph/types.h"
#include "register/op_registry.h"
@@ -85,5 +87,34 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_success) {
EXPECT_EQ(ret, domi::SUCCESS);
}

TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto1) {
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW);
ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(std::string(""),
[](std::string)->std::string{ return "";}, compute_graph);
EXPECT_NE(ret, ge::SUCCESS);
}

TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto2) {
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW);
ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(std::string("null"),
[](std::string)->std::string{ return "";}, compute_graph);
EXPECT_NE(ret, ge::SUCCESS);
}

TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto3) {
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW);

domi::tensorflow::GraphDef graph_def;
auto arg_node = graph_def.add_node();
arg_node->set_name("noop");
arg_node->set_op("NoOp");

ge::graphStatus ret = model_parser->ParseProtoWithSubgraph(graph_def.SerializeAsString(),
[](std::string)->std::string{ return "";}, compute_graph);
EXPECT_EQ(ret, ge::SUCCESS);
}

} // namespace ge

Loading…
Cancel
Save