Browse Source

!369 add function for node with external subgraph

Merge pull request !369 from Fu Jingguo/master
pull/370/head
i-robot Gitee 4 years ago
parent
commit
c5a1007bbf
5 changed files with 228 additions and 0 deletions
  1. +86
    -0
      parser/tensorflow/tensorflow_parser.cc
  2. +1
    -0
      parser/tensorflow/tensorflow_parser.h
  3. +26
    -0
      tests/ut/parser/parser_ut_utils.cc
  4. +19
    -0
      tests/ut/parser/parser_ut_utils.h
  5. +96
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 86
- 0
parser/tensorflow/tensorflow_parser.cc View File

@@ -54,6 +54,7 @@
#include "register/op_registry.h" #include "register/op_registry.h"
#include "register/scope/scope_graph_impl.h" #include "register/scope/scope_graph_impl.h"
#include "register/scope/scope_pass_registry_impl.h" #include "register/scope/scope_pass_registry_impl.h"
#include "parser/common/auto_mapping_subgraph_io_index_func.h"


using ge::const_op_update_vec; using ge::const_op_update_vec;
using ge::OpParserFactory; using ge::OpParserFactory;
@@ -210,6 +211,7 @@ const std::vector<std::string> kMakeOperatorNotByIr = {ge::parser::ARG, ge::pars
const char *const kDpop = "DPOP"; const char *const kDpop = "DPOP";
const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt"; const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt";
const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node"; const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node";
const char *const kExternalModel = "_external_model";
struct ParseArg { struct ParseArg {
const google::protobuf::Message *proto; const google::protobuf::Message *proto;
std::string function_name; std::string function_name;
@@ -302,6 +304,52 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) {
} }
return SUCCESS; return SUCCESS;
} }

Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, ComputeGraphPtr &root_graph){
// Inner function, input params have been checked by caller
Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(graph,
[](int in, int &out)->Status {
out = in;
return SUCCESS;
},
[](int in, int &out)->Status {
out = in;
return SUCCESS;
});
if (status != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]node:%s, sub graph name:%s.",
node->GetName().c_str(), graph.GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to map sub graph input and output, node:%s, sub graph name:%s.",
node->GetName().c_str(), graph.GetName().c_str());
return INTERNAL_ERROR;
}

ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
// Inner function, GetOpDesc has been checked by caller
(void)node->GetOpDesc()->AddSubgraphName("f");
auto ret = NodeUtils::SetSubgraph(*node, 0, compute_graph);
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Set][Subgraph]Node:%s, sub graph name:%s.",
node->GetName().c_str(), compute_graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to set sub graph, node: %s, sub graph name: %s.",
node->GetName().c_str(), compute_graph->GetName().c_str());
return INTERNAL_ERROR;
}
for (const auto &sub_graph : compute_graph->GetAllSubgraphs()) {
ret = root_graph->AddSubgraph(sub_graph);
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Add][Subgraph]Node:%s, sub graph name:%s, sub sub graph name:%s.",
node->GetName().c_str(), compute_graph->GetName().c_str(), sub_graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to add sub graph to root graph, node:%s, sub graph name:%s.",
node->GetName().c_str(), sub_graph->GetName().c_str());
return INTERNAL_ERROR;
}
compute_graph->RemoveSubgraph(sub_graph);
GELOGD("Add subgraph[%s] to root graph[%s].", sub_graph->GetName().c_str(), root_graph->GetName().c_str());
}
return SUCCESS;
}
} // namespace } // namespace


/** /**
@@ -2397,6 +2445,11 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Mes
return ret; return ret;
} }
} }
auto add_ret = AddExternalGraph(root_graph);
if (add_ret != SUCCESS) {
GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str());
return add_ret;
}
PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph");
return SUCCESS; return SUCCESS;
} }
@@ -2460,6 +2513,11 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro
return ret; return ret;
} }
} }
auto add_ret = AddExternalGraph(root_graph);
if (add_ret != SUCCESS) {
GELOGE(add_ret, "Failed to add external graph for root graph %s.", root_graph->GetName().c_str());
return add_ret;
}
PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph");
return SUCCESS; return SUCCESS;
} }
@@ -4034,6 +4092,34 @@ Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping
} }
return SUCCESS; return SUCCESS;
} }

Status TensorFlowModelParser::AddExternalGraph(ComputeGraphPtr &root_graph) {
GE_CHECK_NOTNULL(root_graph);
for (const NodePtr &node : root_graph->GetAllNodes()) {
if (node == nullptr || node->GetOpDesc() == nullptr) {
continue;
}
std::string model_data;
if (AttrUtils::GetStr(node->GetOpDesc(), kExternalModel, model_data) && !model_data.empty()) {
ge::Model model;
auto load_ret = ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.data()), model_data.size(), model);
if (load_ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Parse][ExternalModel]Node:%s.", node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to parse external model, node:%s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
Graph graph = model.GetGraph();
GELOGD("Get subgraph[%s] from model[%s].", graph.GetName().c_str(), node->GetName().c_str());
Status ret = MappingAndAddSubGraph(node, graph, root_graph);
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]Node:%s.", node->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to map and add sub graph, node:%s.", node->GetName().c_str());
return INTERNAL_ERROR;
}
}
}
return SUCCESS;
}
} // namespace ge } // namespace ge


namespace domi { namespace domi {


+ 1
- 0
parser/tensorflow/tensorflow_parser.h View File

@@ -649,6 +649,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser); Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser);
Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph);
static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes);
static Status AddExternalGraph(ComputeGraphPtr &root_graph);


/** /**
* save <node_name, node_def> * save <node_name, node_def>


+ 26
- 0
tests/ut/parser/parser_ut_utils.cc View File

@@ -16,6 +16,7 @@


#include "ut/parser/parser_ut_utils.h" #include "ut/parser/parser_ut_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "graph/utils/graph_utils.h"


namespace ge { namespace ge {
void ParerUTestsUtils::ClearParserInnerCtx() { void ParerUTestsUtils::ClearParserInnerCtx() {
@@ -41,4 +42,29 @@ void ParerUTestsUtils::ClearParserInnerCtx() {
ge::GetParserContext().enable_scope_fusion_passes = ""; ge::GetParserContext().enable_scope_fusion_passes = "";
GELOGI("Clear parser inner context successfully."); GELOGI("Clear parser inner context successfully.");
} }
namespace ut {
NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format,
DataType data_type, std::vector<int64_t> shape) {
auto tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(GeShape(std::move(shape)));
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);

auto op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < in_cnt; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (int i = 0; i < out_cnt; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}

return graph_->AddNode(op_desc);
}
void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) {
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) {
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}
} // namespace ut
} // namespace ge } // namespace ge

+ 19
- 0
tests/ut/parser/parser_ut_utils.h View File

@@ -18,12 +18,31 @@
#define GE_PARSER_TESTS_UT_PARSER_H_ #define GE_PARSER_TESTS_UT_PARSER_H_


#include "framework/omg/parser/parser_inner_ctx.h" #include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/compute_graph.h"


namespace ge { namespace ge {
class ParerUTestsUtils { class ParerUTestsUtils {
public: public:
static void ClearParserInnerCtx(); static void ClearParserInnerCtx();
}; };
namespace ut {
class GraphBuilder {
public:
explicit GraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); }
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
std::vector<int64_t> shape = {1, 1, 224, 224});
void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx);
void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node);
ComputeGraphPtr GetGraph() {
graph_->TopologicalSorting();
return graph_;
}

private:
ComputeGraphPtr graph_;
};
} // namespace ut
} // namespace ge } // namespace ge


#endif // GE_PARSER_TESTS_UT_PARSER_H_ #endif // GE_PARSER_TESTS_UT_PARSER_H_

+ 96
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -17,7 +17,11 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#define private public
#define protected public
#include "parser/tensorflow/tensorflow_parser.h" #include "parser/tensorflow/tensorflow_parser.h"
#undef protected
#undef private
#include "framework/omg/parser/parser_factory.h" #include "framework/omg/parser/parser_factory.h"
#include "graph/operator_reg.h" #include "graph/operator_reg.h"
#include "external/graph/types.h" #include "external/graph/types.h"
@@ -25,6 +29,7 @@
#include "parser/common/register_tbe.h" #include "parser/common/register_tbe.h"
#include "external/parser/tensorflow_parser.h" #include "external/parser/tensorflow_parser.h"
#include "ut/parser/parser_ut_utils.h" #include "ut/parser/parser_ut_utils.h"
#include "graph/model.h"


namespace ge { namespace ge {
class UtestTensorflowParser : public testing::Test { class UtestTensorflowParser : public testing::Test {
@@ -119,4 +124,95 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_with_serialized_proto3) {
EXPECT_EQ(ret, ge::SUCCESS); EXPECT_EQ(ret, ge::SUCCESS);
} }


TEST_F(UtestTensorflowParser, tensorflow_parser_with_external_graph) {
auto make_graph = [](const string &name) {
auto builder = ut::GraphBuilder(name);
auto data1 = builder.AddNode(name + "_input1", "Data", 1, 1);
auto data2 = builder.AddNode(name + "_input2", "Data", 1, 1);
auto add = builder.AddNode(name + "_add", "Add", 2, 1);
auto net_output = builder.AddNode(name + "_net_output", "NetOutput", 1, 1);
builder.AddDataEdge(data1, 0, add, 0);
builder.AddDataEdge(data2, 0, add, 1);
builder.AddDataEdge(add, 0, net_output, 0);
return builder.GetGraph();
};
// 1. Create root graph
ComputeGraphPtr root_graph = make_graph("root_graph");

// 2. Create ONNX sub graph
// 2.1 Sub graph of onnx graph
ge::ComputeGraphPtr sub_sub_graph = ge::parser::MakeShared<ge::ComputeGraph>("sub_sub");
// 2.2 ONNX graph
ComputeGraphPtr sub_graph = make_graph("sub_graph");
auto add = sub_graph->FindNode("sub_graph_add");
ASSERT_NE(add, nullptr);
add->GetOpDesc()->AddSubgraphName("sub_sub_graph");
add->GetOpDesc()->SetSubgraphInstanceName(0, sub_sub_graph->GetName());
sub_graph->AddSubGraph(sub_sub_graph);
auto input1 = sub_graph->FindNode("sub_graph_input1");
ASSERT_NE(input1, nullptr);
AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_INDEX, 0);
auto input2 = sub_graph->FindNode("sub_graph_input2");
ASSERT_NE(input2, nullptr);
AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_INDEX, 1);

// 3. Serialize ONNX graph to string
// 3.1 normal
ge::Model model("model", "");
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph));
Buffer buffer;
graphStatus save_ret = model.Save(buffer, false);
ASSERT_EQ(save_ret, GRAPH_SUCCESS);
std::string external_graph(reinterpret_cast<const char *>(buffer.GetData()), buffer.GetSize());
// model will failed
input1->GetOpDesc()->DelAttr(ATTR_NAME_INDEX);
ge::Model model_will_fail("model_will_fail", "");
model_will_fail.SetGraph(GraphUtils::CreateGraphFromComputeGraph(sub_graph));
Buffer buffer_fail;
save_ret = model_will_fail.Save(buffer_fail, false);
ASSERT_EQ(save_ret, GRAPH_SUCCESS);
std::string external_graph_fail(reinterpret_cast<const char *>(buffer_fail.GetData()), buffer_fail.GetSize());

// 4. Set string to function node
auto root_add = root_graph->FindNode("root_graph_add");
ASSERT_NE(root_add, nullptr);
AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph);
auto root_input1 = root_graph->FindNode("root_graph_input1");
ASSERT_NE(root_input1, nullptr);
AttrUtils::SetInt(root_input1->GetOpDesc(), ATTR_NAME_INDEX, 0);
auto root_input2 = root_graph->FindNode("root_graph_input2");
ASSERT_NE(root_input2, nullptr);
AttrUtils::SetInt(root_input2->GetOpDesc(), ATTR_NAME_INDEX, 1);

// 5. Run test (normal)
auto ret = TensorFlowModelParser::AddExternalGraph(root_graph);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 2);
EXPECT_EQ(sub_graph->GetAllSubgraphs().size(), 1);
EXPECT_NE(root_graph->GetSubgraph(sub_graph->GetName()), nullptr);
EXPECT_EQ(root_graph->GetSubgraph(sub_graph->GetName())->GetAllSubgraphs().size(), 0);

// 6. Run test (failed)
// 6.1 Failed to load model
AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", "dummy string");
ret = TensorFlowModelParser::AddExternalGraph(root_graph);
EXPECT_EQ(ret, INTERNAL_ERROR);

// 6.2 Failed to map sub graph
AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph_fail);
ret = TensorFlowModelParser::AddExternalGraph(root_graph);
EXPECT_EQ(ret, INTERNAL_ERROR);

// 6.3 Failed to set sub graph to node
AttrUtils::SetStr(root_add->GetOpDesc(), "_external_model", external_graph);
root_add->SetOwnerComputeGraph(nullptr);
ret = TensorFlowModelParser::AddExternalGraph(root_graph);
EXPECT_EQ(ret, INTERNAL_ERROR);

// 6.4 Failed to add sub sub graph
root_add->SetOwnerComputeGraph(nullptr);
root_graph->RemoveSubGraph(sub_graph);
ret = TensorFlowModelParser::AddExternalGraph(root_graph);
EXPECT_EQ(ret, INTERNAL_ERROR);
}
} // namespace ge } // namespace ge

Loading…
Cancel
Save