Browse Source

parser one to many

pull/124/head
y00500818 5 years ago
parent
commit
ecfa6f1a12
6 changed files with 253 additions and 2 deletions
  1. +1
    -0
      parser/common/CMakeLists.txt
  2. +1
    -0
      parser/common/module.mk
  3. +202
    -0
      parser/common/parser_utils.cc
  4. +37
    -0
      parser/common/parser_utils.h
  5. +3
    -0
      parser/onnx/onnx_parser.cc
  6. +9
    -2
      parser/tensorflow/tensorflow_parser.cc

+ 1
- 0
parser/common/CMakeLists.txt View File

@@ -21,6 +21,7 @@ set(SRC_LIST
"op_map.cc"
"../../../ge/graph/passes/pass_manager.cc"
"../../../ge/common/thread_pool.cc"
"parser_utils.cc"
)

############ libparser_common.so ############


+ 1
- 0
parser/common/module.mk View File

@@ -32,6 +32,7 @@ COMMON_LOCAL_SRC_FILES := \
op_def/op_schema.cc \
op_def/operator.cc \
op_map.cc \
parser_utils.cc \

FMK_COMMON_SRC_FILES := \
../../common/types.cc \


+ 202
- 0
parser/common/parser_utils.cc View File

@@ -0,0 +1,202 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser_utils.h"
#include "external/ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "framework/omg/parser/parser_types.h"
#include "graph/anchor.h"
#include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_adapter.h"
#include "graph/utils/op_desc_utils.h"
#include "register/op_registry.h"

namespace ge {
Status ParserUtils::ExpandOneToManyGraph(Graph &graph) {
GELOGD("Begin run ParserUtils::ExpandOneToManyGraph.");
for (const auto &gn : graph.GetDirectNode()) {
NodePtr n = NodeAdapter::GNode2Node(gn);
GE_CHECK_NOTNULL(n);
std::string ori_type;
(void)AttrUtils::GetStr(n->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type);
domi::ParseOpToGraphFunc parse_op_to_graph_func =
domi::OpRegistry::Instance()->GetParseOpToGraphFunc(n->GetType(), ori_type);
if (parse_op_to_graph_func == nullptr) {
GELOGD("node:%s type:%s ori type:%s has no parse_op_to_graph_func.",
n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str());
continue;
}
GELOGI("node:%s type:%s ori type:%s has registered one to many parser func.",
n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str());
Graph subgraph("one_to_many_graph");
Operator op = OpDescUtils::CreateOperatorFromNode(n);
Status ret = parse_op_to_graph_func(op, subgraph);
if (ret != SUCCESS) {
GELOGE(FAILED, "Get one to many graph failed for op:%s.", op.GetName().c_str());
return FAILED;
}
ret = ExpandNodeToSubgraph(subgraph, n, graph);
if (ret != SUCCESS) {
GELOGE(FAILED, "Expand one to many graph failed for op:%s.", op.GetName().c_str());
return FAILED;
}
}
GELOGD("run ParserUtils::ExpandOneToManyGraph success.");
return SUCCESS;
}

Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph) {
ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph);
GE_CHECK_NOTNULL(sub_compute_graph);
ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);

// add subgraph node to graph.
std::unordered_map<std::string, NodePtr> all_new_nodes;
std::vector<NodePtr> input_nodes;
for (const auto &n : sub_compute_graph->GetDirectNode()) {
auto new_node = compute_graph->AddNode(n);
GE_CHECK_NOTNULL(new_node);
all_new_nodes[new_node->GetName()] = new_node;
if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Set owner graph for node:%s failed.", new_node->GetName().c_str());
return FAILED;
}

if (new_node->GetType() == ge::parser::DATA) {
input_nodes.emplace_back(new_node);
}
}

// handle input context.
Status ret = HandleInputContext(node, input_nodes, compute_graph);
if (ret != SUCCESS) {
GELOGE(FAILED, "run ParserUtils::HandleInputContext failed.");
return FAILED;
}

// handle output context.
std::vector<std::pair<NodePtr, int32_t>> out_node_index = sub_compute_graph->GetGraphOutNodesInfo();
ret = HandleOutputContext(node, out_node_index);
if (ret != SUCCESS) {
GELOGE(FAILED, "run ParserUtils::HandleOutputContext failed.");
return FAILED;
}

graphStatus graph_status = GraphUtils::RemoveNodeWithoutRelink(compute_graph, node);
if (graph_status != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove node:%s failed.", node->GetName().c_str());
return FAILED;
}
graph_status = compute_graph->TopologicalSorting();
if (graph_status != GRAPH_SUCCESS) {
GELOGE(FAILED, "Topological sorting failed.");
return FAILED;
}
return SUCCESS;
}

Status ParserUtils::HandleInputContext(const NodePtr &node,
const std::vector<NodePtr> &input_nodes,
const ComputeGraphPtr &compute_graph) {
GE_CHECK_NOTNULL(node);
for (const auto &in_n : input_nodes) {
GE_CHECK_NOTNULL(in_n);
int index;
if (!AttrUtils::GetInt(in_n->GetOpDesc(), ATTR_NAME_INDEX, index)) {
GELOGE(FAILED, "Get attr index of node:%s failed.", in_n->GetName().c_str());
return FAILED;
}
GELOGD("Begin to handle input node:%s with index:%d.", in_n->GetName().c_str(), index);
// get node's in data anchor and peer out anchor
auto node_in_anchor = node->GetInDataAnchor(index);
GE_CHECK_NOTNULL(node_in_anchor);
auto src_out_anchor = node_in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(src_out_anchor);
auto data_out_anchor = in_n->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(data_out_anchor);
for (const auto &peer_in_anchor : data_out_anchor->GetPeerInDataAnchors()) {
// add data edge
graphStatus ret = GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor);
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "remove data out anchor and peer in anchor failed.");
return FAILED;
}
ret = GraphUtils::RemoveEdge(src_out_anchor, node_in_anchor);
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "remove node in anchor and peer out anchor failed.");
return FAILED;
}
ret = GraphUtils::AddEdge(src_out_anchor, peer_in_anchor);
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "link node's peer out anchor and data's peer in anchor failed.");
return FAILED;
}

// add control edge
if (node->GetInControlAnchor() != nullptr) {
for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) {
graphStatus ret = GraphUtils::AddEdge(out_anchor, peer_in_anchor->GetOwnerNode()->GetInControlAnchor());
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "add control edge failed.");
return FAILED;
}
}
}
}
graphStatus ret = GraphUtils::RemoveNodeWithoutRelink(compute_graph, in_n);
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "remove node:%s failed.", in_n->GetName().c_str());
return FAILED;
}
}
return SUCCESS;
}

Status ParserUtils::HandleOutputContext(const NodePtr &node,
const std::vector<std::pair<NodePtr, int32_t>> &out_node_index) {
GE_CHECK_NOTNULL(node);
GELOGD("The size of out node is %zu", out_node_index.size());
for (size_t index = 0; index < out_node_index.size(); index++) {
auto node_out_anchor = node->GetOutDataAnchor(index);
if (node_out_anchor == nullptr) {
continue;
}

NodePtr out_node = out_node_index[index].first;
int32_t out_index = out_node_index[index].second;
GELOGD("Begin to handle output node:%s[%zu] with index:%zu", out_node->GetName().c_str(), out_index, index);
auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor.
GE_CHECK_NOTNULL(src_out_anchor);
for (const auto &dest_in_anchor : node_out_anchor->GetPeerInDataAnchors()) {
graphStatus ret = GraphUtils::RemoveEdge(node_out_anchor, dest_in_anchor);
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "remove node's out anchor and peer in anchor failed.");
return FAILED;
}
ret = GraphUtils::AddEdge(src_out_anchor, dest_in_anchor);
if (ret != GRAPH_SUCCESS) {
GELOGE(FAILED, "link node's peer out anchor and out node's out anchor failed.");
return FAILED;
}
}
}
return SUCCESS;
}
} // namespace ge

+ 37
- 0
parser/common/parser_utils.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_PARSER_UTILS_H_
#define PARSER_COMMON_PARSER_UTILS_H_

#include "graph/graph.h"
#include "graph/node.h"
#include "external/ge/ge_api_error_codes.h"

namespace ge {
class ParserUtils {
public:
static Status ExpandOneToManyGraph(Graph &graph);

private:
static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph);
static Status HandleInputContext(const NodePtr &node,
const std::vector<NodePtr> &input_nodes,
const ComputeGraphPtr &compute_graph);
static Status HandleOutputContext(const NodePtr &node, const std::vector<std::pair<NodePtr, int32_t>> &out_node_index);
};
} // namespace ge
#endif // PARSER_COMMON_PARSER_UTILS_H_

+ 3
- 0
parser/onnx/onnx_parser.cc View File

@@ -28,6 +28,7 @@
#include "onnx_util.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "parser/common/parser_utils.h"
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"

@@ -533,6 +534,8 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
}
graph.SetInputs(input_ops).SetOutputs(output_indexs);

GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph));

UpdateFormat(graph);

GELOGI("Onnx model parser success.");


+ 9
- 2
parser/tensorflow/tensorflow_parser.cc View File

@@ -46,6 +46,7 @@
#include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "parser/common/acl_graph_parser_util.h"
#include "parser/common/parser_utils.h"
#include "parser/tensorflow/tensorflow_fusion_op_parser.h"
#include "parser/tensorflow/tensorflow_fusionop_util.h"
#include "parser/tensorflow/tensorflow_op_parser.h"
@@ -167,8 +168,8 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque
auto i = subgraph_name_to_index.second;
auto subgraph_iname = op_desc->GetSubgraphInstanceName(i);
if (subgraph_iname.empty()) {
GELOGE(PARAM_INVALID, "The subgraph index %u of node %s is empty", i, node->GetName().c_str());
return PARAM_INVALID;
GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str());
continue;
}

// A function may be referenced multiple times in TF, change the graph name to ensure it is unique in GE
@@ -1396,6 +1397,8 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
DeleteFuisonNodeDef();

GE_RETURN_IF_ERROR(AddEdges(graph));
Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph));
GE_RETURN_IF_ERROR(RemoveIsolateNode(graph));
GE_RETURN_IF_ERROR(graph->TopologicalSorting());

@@ -2190,6 +2193,10 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto,
GELOGD("[TF Parser] Add framework node success");

ret = AddEdges(graph);

Graph dest_graph = GraphUtils::CreateGraphFromComputeGraph(graph);
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(dest_graph));

DeleteFuisonNodeDef();
GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed");
GELOGD("[TF Parser] Add edges success");


Loading…
Cancel
Save