diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 8d9edb65..e4418619 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -157,6 +157,7 @@ set(TRAIN_SRC_LIST "graph/passes/compile_nodes_pass.cc" "graph/passes/constant_folding_pass.cc" "graph/passes/constant_fuse_same_pass.cc" + "graph/passes/fuse_same_data_pass.cc" "graph/passes/remove_same_const_pass.cc" "graph/passes/useless_control_out_remove_pass.cc" "graph/passes/control_trigger_pass.cc" @@ -439,6 +440,7 @@ set(INFER_SRC_LIST "graph/passes/net_output_pass.cc" "graph/passes/replace_transshape_pass.cc" "graph/passes/constant_fuse_same_pass.cc" + "graph/passes/fuse_same_data_pass.cc" "graph/passes/print_op_pass.cc" "graph/passes/no_use_reshape_remove_pass.cc" "graph/passes/iterator_op_pass.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index 74d09404..79e59b30 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -103,6 +103,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/net_output_pass.cc \ graph/passes/replace_transshape_pass.cc \ graph/passes/constant_fuse_same_pass.cc \ + graph/passes/fuse_same_data_pass.cc \ graph/passes/print_op_pass.cc \ graph/passes/no_use_reshape_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 5a99dc8c..8f193c71 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -127,6 +127,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/compile_nodes_pass.cc \ graph/passes/constant_folding_pass.cc \ graph/passes/constant_fuse_same_pass.cc \ + graph/passes/fuse_same_data_pass.cc \ graph/passes/remove_same_const_pass.cc \ graph/passes/useless_control_out_remove_pass.cc \ graph/passes/control_trigger_pass.cc \ diff --git a/ge/graph/load/new_model_manager/zero_copy_offset.h b/ge/graph/load/new_model_manager/zero_copy_offset.h index 8ead742d..66fcd887 100644 --- a/ge/graph/load/new_model_manager/zero_copy_offset.h +++ b/ge/graph/load/new_model_manager/zero_copy_offset.h @@ -65,7 +65,7 @@ class ZeroCopyOffset { // data_size of Data/Netoutput int64_t GetDataSize() const { return data_size_; } // value of *outside_addrs_ from davinci_model - std::vector>> &GetOutsideAddrs() { return outside_addrs_; } + const std::vector>> &GetOutsideAddrs() { return outside_addrs_; } // name of op std::string GetOpName() const { return op_name_; } diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index c4f91036..d3f066f1 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -55,6 +55,7 @@ #include "graph/passes/dimension_adjust_pass.h" #include "graph/passes/dimension_compute_pass.h" #include "graph/passes/flow_ctrl_pass.h" +#include "graph/passes/fuse_same_data_pass.h" #include "graph/passes/identity_pass.h" #include "graph/passes/input_output_connection_identify_pass.h" #include "graph/passes/iterator_op_pass.h" @@ -2112,6 +2113,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); GE_CHK_STATUS_RET( after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass)); + GE_CHK_STATUS_RET( + after_merge_passes.AddPass("OptimizeStage1_1::FuseSameDataPass", new (std::nothrow) FuseSameDataPass)); GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::CommonSubexpressionEliminationPass", new (std::nothrow) CommonSubexpressionEliminationPass)); GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::PermutePass", new (std::nothrow) PermutePass)) diff --git a/ge/graph/passes/fuse_same_data_pass.cc b/ge/graph/passes/fuse_same_data_pass.cc new file mode 100644 index 00000000..aaacf672 --- /dev/null +++ b/ge/graph/passes/fuse_same_data_pass.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/fuse_same_data_pass.h" + +#include +#include +#include +#include +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/node_utils.h" + +using std::map; +using std::vector; +using std::string; + +namespace ge { +Status FuseSameDataPass::Run(ge::ComputeGraphPtr graph) { + if (graph == nullptr) { + GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); + return GE_GRAPH_PARAM_NULLPTR; + } + GELOGI("FuseSameDataPass in."); + map>> need_fuse_nodes; + GetFuseDataNodes(graph, need_fuse_nodes); + + return FuseDataNodes(graph, need_fuse_nodes); +} + +void FuseSameDataPass::GetFuseDataNodes(ComputeGraphPtr &graph, + map>> &need_fuse_nodes) { + // need_fuse_nodes is [fused_const_node, {need to be fused data nodes}] + for (auto &node : graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(op_desc == nullptr, continue); + if (node->GetType() != DATA) { + continue; + } + const ComputeGraphPtr &sub_graph = node->GetOwnerComputeGraph(); + const NodePtr &parent_node = NodeUtils::GetParentInput(node); + if (sub_graph != nullptr && parent_node != nullptr) { + if (parent_node->GetType() == CONSTANT || parent_node->GetType() == CONSTANTOP) { + GELOGD("Data node %s exist in %s, parent node is %s, parent node type is %s.", node->GetName().c_str(), + sub_graph->GetName().c_str(), parent_node->GetName().c_str(), parent_node->GetType().c_str()); + if (need_fuse_nodes.find(sub_graph) != need_fuse_nodes.end()) { + // data node in same graph + auto &const_data_map = need_fuse_nodes[sub_graph]; + if (const_data_map.find(parent_node) == const_data_map.end()) { + const_data_map[parent_node] = {node}; + } else { + const_data_map[parent_node].emplace_back(node); + } + need_fuse_nodes[sub_graph] = const_data_map; + } else { + // data node in different graph + map> const_data_map; + const_data_map[parent_node] = {node}; + need_fuse_nodes[sub_graph] = const_data_map; + } + } + } + } + GELOGD("Need fuse data exist %zu sub graph.", need_fuse_nodes.size()); +} + +Status FuseSameDataPass::FuseDataNodes(ComputeGraphPtr &graph, + map>> &need_fuse_nodes) { + for (auto &sub_graph_ : need_fuse_nodes) { + auto &const_data_map = sub_graph->second; + for (auto &iter : const_data_map) { + auto nodes = iter->second; + auto first_node = nodes.at(0); + size_t len = nodes.size(); + GELOGD("Need to fuse %zu data nodes in %s, their parent node is %s.", len, graph->GetName().c_str(), + iter->first->GetName().c_str()); + for (size_t i = 1; i < len; ++i) { + auto node = nodes.at(i); + GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(), + first_node->GetName().c_str(), graph->GetName().c_str()); + // the data node which can be fused has none input(both data and control in) + if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) { + return FAILED; + } + if (MoveOutDataEdges(node, first_node) != SUCCESS) { + return FAILED; + } + if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) { + GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str()); + return FAILED; + } + } + } + } + return SUCCESS; +} + +// src_node will be replaced by dst_node +Status FuseSameDataPass::MoveOutDataEdges(NodePtr &src_node, NodePtr &dst_node) { + // key is node_name-in_index + std::map src_out_node_to_indexs; + GetOutDataNodeToIndexMap(src_node, src_out_node_to_indexs); + if (src_out_node_to_indexs.empty()) { + return SUCCESS; + } + + std::map dst_out_node_to_indexs; + GetOutDataNodeToIndexMap(dst_node, dst_out_node_to_indexs); + + auto dst_out_data_anchor = dst_node->GetOutDataAnchor(0); + GE_CHECK_NOTNULL(dst_out_data_anchor); + auto src_out_data_anchor = src_node->GetOutDataAnchor(0); + GE_CHECK_NOTNULL(src_out_data_anchor); + src_out_data_anchor->UnlinkAll(); + for (auto it = src_out_node_to_indexs.begin(); it != src_out_node_to_indexs.end(); ++it) { + if (dst_out_node_to_indexs.count(it->first) > 0) { + continue; // exclusion of duplication + } + auto ret = dst_out_data_anchor->LinkTo(it->second); + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed to move out data edge from %s to %s", src_node->GetName().c_str(), + dst_node->GetName().c_str()); + return FAILED; + } + } + return SUCCESS; +} + +void FuseSameDataPass::GetOutDataNodeToIndexMap(NodePtr &node, std::map &out_node_to_indexs) { + auto out_data_anchor = node->GetOutDataAnchor(0); + GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor); + auto peer_in_anchors = out_data_anchor->GetPeerInDataAnchors(); + if (!peer_in_anchors.empty()) { + for (auto &anchor : peer_in_anchors) { + int index = anchor->GetIdx(); + NodePtr out_node = anchor->GetOwnerNode(); + if (out_node == nullptr) { + continue; + } + string key_name = out_node->GetName() + "-" + std::to_string(index); + out_node_to_indexs[key_name] = anchor; + } + } +} +} // namespace ge diff --git a/ge/graph/passes/fuse_same_data_pass.h b/ge/graph/passes/fuse_same_data_pass.h new file mode 100755 index 00000000..4f46948a --- /dev/null +++ b/ge/graph/passes/fuse_same_data_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_FUSE_SAME_DATA_PASS_H_ +#define GE_GRAPH_PASSES_FUSE_SAME_DATA_PASS_H_ + +#include +#include +#include +#include "graph/types.h" +#include "inc/graph_pass.h" + +namespace ge { +class FuseSameDataPass : public GraphPass { + public: + Status Run(ge::ComputeGraphPtr graph) override; + + private: + void GetFuseDataNodes(ComputeGraphPtr &graph, + std::map>> &need_fuse_nodes); + Status FuseDataNodes(ComputeGraphPtr &graph, + std::map>> &need_fuse_nodes); + Status MoveOutDataEdges(NodePtr &src_node, NodePtr &dst_node); + void GetOutDataNodeToIndexMap(NodePtr &node, std::map &out_node_to_indexs); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_FUSE_SAME_DATA_PASS_H_