Browse Source

add fuse same data pass for online infer dynamic

pull/882/head
zhou_lili 5 years ago
parent
commit
219e81f10e
7 changed files with 206 additions and 1 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +1
    -0
      ge/ge_inference.mk
  3. +1
    -0
      ge/ge_runner.mk
  4. +1
    -1
      ge/graph/load/new_model_manager/zero_copy_offset.h
  5. +3
    -0
      ge/graph/manager/graph_manager.cc
  6. +158
    -0
      ge/graph/passes/fuse_same_data_pass.cc
  7. +40
    -0
      ge/graph/passes/fuse_same_data_pass.h

+ 2
- 0
ge/CMakeLists.txt View File

@@ -157,6 +157,7 @@ set(TRAIN_SRC_LIST
"graph/passes/compile_nodes_pass.cc"
"graph/passes/constant_folding_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
"graph/passes/fuse_same_data_pass.cc"
"graph/passes/remove_same_const_pass.cc"
"graph/passes/useless_control_out_remove_pass.cc"
"graph/passes/control_trigger_pass.cc"
@@ -439,6 +440,7 @@ set(INFER_SRC_LIST
"graph/passes/net_output_pass.cc"
"graph/passes/replace_transshape_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
"graph/passes/fuse_same_data_pass.cc"
"graph/passes/print_op_pass.cc"
"graph/passes/no_use_reshape_remove_pass.cc"
"graph/passes/iterator_op_pass.cc"


+ 1
- 0
ge/ge_inference.mk View File

@@ -103,6 +103,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/net_output_pass.cc \
graph/passes/replace_transshape_pass.cc \
graph/passes/constant_fuse_same_pass.cc \
graph/passes/fuse_same_data_pass.cc \
graph/passes/print_op_pass.cc \
graph/passes/no_use_reshape_remove_pass.cc \
graph/passes/iterator_op_pass.cc \


+ 1
- 0
ge/ge_runner.mk View File

@@ -127,6 +127,7 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/compile_nodes_pass.cc \
graph/passes/constant_folding_pass.cc \
graph/passes/constant_fuse_same_pass.cc \
graph/passes/fuse_same_data_pass.cc \
graph/passes/remove_same_const_pass.cc \
graph/passes/useless_control_out_remove_pass.cc \
graph/passes/control_trigger_pass.cc \


+ 1
- 1
ge/graph/load/new_model_manager/zero_copy_offset.h View File

@@ -65,7 +65,7 @@ class ZeroCopyOffset {
// data_size of Data/Netoutput
int64_t GetDataSize() const { return data_size_; }
// value of *outside_addrs_ from davinci_model
std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; }
const std::vector<std::map<const void *, std::vector<void *>>> &GetOutsideAddrs() { return outside_addrs_; }
// name of op
std::string GetOpName() const { return op_name_; }



+ 3
- 0
ge/graph/manager/graph_manager.cc View File

@@ -55,6 +55,7 @@
#include "graph/passes/dimension_adjust_pass.h"
#include "graph/passes/dimension_compute_pass.h"
#include "graph/passes/flow_ctrl_pass.h"
#include "graph/passes/fuse_same_data_pass.h"
#include "graph/passes/identity_pass.h"
#include "graph/passes/input_output_connection_identify_pass.h"
#include "graph/passes/iterator_op_pass.h"
@@ -2112,6 +2113,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass));
GE_CHK_STATUS_RET(
after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass));
GE_CHK_STATUS_RET(
after_merge_passes.AddPass("OptimizeStage1_1::FuseSameDataPass", new (std::nothrow) FuseSameDataPass));
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass",
new (std::nothrow) CommonSubexpressionEliminationPass));
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass))


+ 158
- 0
ge/graph/passes/fuse_same_data_pass.cc View File

@@ -0,0 +1,158 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/fuse_same_data_pass.h"

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/node_utils.h"

using std::map;
using std::vector;
using std::string;

namespace ge {
Status FuseSameDataPass::Run(ge::ComputeGraphPtr graph) {
if (graph == nullptr) {
GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null.");
return GE_GRAPH_PARAM_NULLPTR;
}
GELOGI("FuseSameDataPass in.");
map<ComputeGraphPtr, map<NodePtr, vector<NodePtr>>> need_fuse_nodes;
GetFuseDataNodes(graph, need_fuse_nodes);

return FuseDataNodes(graph, need_fuse_nodes);
}

void FuseSameDataPass::GetFuseDataNodes(ComputeGraphPtr &graph,
map<ComputeGraphPtr, map<NodePtr, vector<NodePtr>>> &need_fuse_nodes) {
// need_fuse_nodes is [fused_const_node, {need to be fused data nodes}]
for (auto &node : graph->GetAllNodes()) {
auto op_desc = node->GetOpDesc();
GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
if (node->GetType() != DATA) {
continue;
}
const ComputeGraphPtr &sub_graph = node->GetOwnerComputeGraph();
const NodePtr &parent_node = NodeUtils::GetParentInput(node);
if (sub_graph != nullptr && parent_node != nullptr) {
if (parent_node->GetType() == CONSTANT || parent_node->GetType() == CONSTANTOP) {
GELOGD("Data node %s exist in %s, parent node is %s, parent node type is %s.", node->GetName().c_str(),
sub_graph->GetName().c_str(), parent_node->GetName().c_str(), parent_node->GetType().c_str());
if (need_fuse_nodes.find(sub_graph) != need_fuse_nodes.end()) {
// data node in same graph
auto &const_data_map = need_fuse_nodes[sub_graph];
if (const_data_map.find(parent_node) == const_data_map.end()) {
const_data_map[parent_node] = {node};
} else {
const_data_map[parent_node].emplace_back(node);
}
need_fuse_nodes[sub_graph] = const_data_map;
} else {
// data node in different graph
map<NodePtr, vector<NodePtr>> const_data_map;
const_data_map[parent_node] = {node};
need_fuse_nodes[sub_graph] = const_data_map;
}
}
}
}
GELOGD("Need fuse data exist %zu sub graph.", need_fuse_nodes.size());
}

Status FuseSameDataPass::FuseDataNodes(ComputeGraphPtr &graph,
map<ComputeGraphPtr, map<NodePtr, vector<NodePtr>>> &need_fuse_nodes) {
for (auto &sub_graph_ : need_fuse_nodes) {
auto &const_data_map = sub_graph->second;
for (auto &iter : const_data_map) {
auto nodes = iter->second;
auto first_node = nodes.at(0);
size_t len = nodes.size();
GELOGD("Need to fuse %zu data nodes in %s, their parent node is %s.", len, graph->GetName().c_str(),
iter->first->GetName().c_str());
for (size_t i = 1; i < len; ++i) {
auto node = nodes.at(i);
GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(),
first_node->GetName().c_str(), graph->GetName().c_str());
// the data node which can be fused has none input(both data and control in)
if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) {
return FAILED;
}
if (MoveOutDataEdges(node, first_node) != SUCCESS) {
return FAILED;
}
if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) {
GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str());
return FAILED;
}
}
}
}
return SUCCESS;
}

// src_node will be replaced by dst_node
Status FuseSameDataPass::MoveOutDataEdges(NodePtr &src_node, NodePtr &dst_node) {
// key is node_name-in_index
std::map<string, InDataAnchorPtr> src_out_node_to_indexs;
GetOutDataNodeToIndexMap(src_node, src_out_node_to_indexs);
if (src_out_node_to_indexs.empty()) {
return SUCCESS;
}

std::map<string, InDataAnchorPtr> dst_out_node_to_indexs;
GetOutDataNodeToIndexMap(dst_node, dst_out_node_to_indexs);

auto dst_out_data_anchor = dst_node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(dst_out_data_anchor);
auto src_out_data_anchor = src_node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(src_out_data_anchor);
src_out_data_anchor->UnlinkAll();
for (auto it = src_out_node_to_indexs.begin(); it != src_out_node_to_indexs.end(); ++it) {
if (dst_out_node_to_indexs.count(it->first) > 0) {
continue; // exclusion of duplication
}
auto ret = dst_out_data_anchor->LinkTo(it->second);
if (ret != SUCCESS) {
GELOGE(FAILED, "Failed to move out data edge from %s to %s", src_node->GetName().c_str(),
dst_node->GetName().c_str());
return FAILED;
}
}
return SUCCESS;
}

void FuseSameDataPass::GetOutDataNodeToIndexMap(NodePtr &node, std::map<string, InDataAnchorPtr> &out_node_to_indexs) {
auto out_data_anchor = node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor);
auto peer_in_anchors = out_data_anchor->GetPeerInDataAnchors();
if (!peer_in_anchors.empty()) {
for (auto &anchor : peer_in_anchors) {
int index = anchor->GetIdx();
NodePtr out_node = anchor->GetOwnerNode();
if (out_node == nullptr) {
continue;
}
string key_name = out_node->GetName() + "-" + std::to_string(index);
out_node_to_indexs[key_name] = anchor;
}
}
}
} // namespace ge

+ 40
- 0
ge/graph/passes/fuse_same_data_pass.h View File

@@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_FUSE_SAME_DATA_PASS_H_
#define GE_GRAPH_PASSES_FUSE_SAME_DATA_PASS_H_

#include <map>
#include <set>
#include <vector>
#include "graph/types.h"
#include "inc/graph_pass.h"

namespace ge {
class FuseSameDataPass : public GraphPass {
public:
Status Run(ge::ComputeGraphPtr graph) override;

private:
void GetFuseDataNodes(ComputeGraphPtr &graph,
std::map<ComputeGraphPtr, std::map<NodePtr, std::vector<NodePtr>>> &need_fuse_nodes);
Status FuseDataNodes(ComputeGraphPtr &graph,
std::map<ComputeGraphPtr, std::map<NodePtr, std::vector<NodePtr>>> &need_fuse_nodes);
Status MoveOutDataEdges(NodePtr &src_node, NodePtr &dst_node);
void GetOutDataNodeToIndexMap(NodePtr &node, std::map<string, InDataAnchorPtr> &out_node_to_indexs);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_FUSE_SAME_DATA_PASS_H_

Loading…
Cancel
Save