|
|
|
@@ -0,0 +1,158 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "graph/passes/fuse_same_data_pass.h" |
|
|
|
|
|
|
|
#include <map> |
|
|
|
#include <memory> |
|
|
|
#include <string> |
|
|
|
#include <vector> |
|
|
|
#include "graph/utils/op_desc_utils.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "graph/utils/node_utils.h" |
|
|
|
|
|
|
|
using std::map; |
|
|
|
using std::vector; |
|
|
|
using std::string; |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
Status FuseSameDataPass::Run(ge::ComputeGraphPtr graph) { |
|
|
|
if (graph == nullptr) { |
|
|
|
GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); |
|
|
|
return GE_GRAPH_PARAM_NULLPTR; |
|
|
|
} |
|
|
|
GELOGI("FuseSameDataPass in."); |
|
|
|
map<ComputeGraphPtr, map<NodePtr, vector<NodePtr>>> need_fuse_nodes; |
|
|
|
GetFuseDataNodes(graph, need_fuse_nodes); |
|
|
|
|
|
|
|
return FuseDataNodes(graph, need_fuse_nodes); |
|
|
|
} |
|
|
|
|
|
|
|
void FuseSameDataPass::GetFuseDataNodes(ComputeGraphPtr &graph, |
|
|
|
map<ComputeGraphPtr, map<NodePtr, vector<NodePtr>>> &need_fuse_nodes) { |
|
|
|
// need_fuse_nodes is [fused_const_node, {need to be fused data nodes}] |
|
|
|
for (auto &node : graph->GetAllNodes()) { |
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
GE_IF_BOOL_EXEC(op_desc == nullptr, continue); |
|
|
|
if (node->GetType() != DATA) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
const ComputeGraphPtr &sub_graph = node->GetOwnerComputeGraph(); |
|
|
|
const NodePtr &parent_node = NodeUtils::GetParentInput(node); |
|
|
|
if (sub_graph != nullptr && parent_node != nullptr) { |
|
|
|
if (parent_node->GetType() == CONSTANT || parent_node->GetType() == CONSTANTOP) { |
|
|
|
GELOGD("Data node %s exist in %s, parent node is %s, parent node type is %s.", node->GetName().c_str(), |
|
|
|
sub_graph->GetName().c_str(), parent_node->GetName().c_str(), parent_node->GetType().c_str()); |
|
|
|
if (need_fuse_nodes.find(sub_graph) != need_fuse_nodes.end()) { |
|
|
|
// data node in same graph |
|
|
|
auto &const_data_map = need_fuse_nodes[sub_graph]; |
|
|
|
if (const_data_map.find(parent_node) == const_data_map.end()) { |
|
|
|
const_data_map[parent_node] = {node}; |
|
|
|
} else { |
|
|
|
const_data_map[parent_node].emplace_back(node); |
|
|
|
} |
|
|
|
need_fuse_nodes[sub_graph] = const_data_map; |
|
|
|
} else { |
|
|
|
// data node in different graph |
|
|
|
map<NodePtr, vector<NodePtr>> const_data_map; |
|
|
|
const_data_map[parent_node] = {node}; |
|
|
|
need_fuse_nodes[sub_graph] = const_data_map; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGD("Need fuse data exist %zu sub graph.", need_fuse_nodes.size()); |
|
|
|
} |
|
|
|
|
|
|
|
Status FuseSameDataPass::FuseDataNodes(ComputeGraphPtr &graph, |
|
|
|
map<ComputeGraphPtr, map<NodePtr, vector<NodePtr>>> &need_fuse_nodes) { |
|
|
|
for (auto &sub_graph_ : need_fuse_nodes) { |
|
|
|
auto &const_data_map = sub_graph->second; |
|
|
|
for (auto &iter : const_data_map) { |
|
|
|
auto nodes = iter->second; |
|
|
|
auto first_node = nodes.at(0); |
|
|
|
size_t len = nodes.size(); |
|
|
|
GELOGD("Need to fuse %zu data nodes in %s, their parent node is %s.", len, graph->GetName().c_str(), |
|
|
|
iter->first->GetName().c_str()); |
|
|
|
for (size_t i = 1; i < len; ++i) { |
|
|
|
auto node = nodes.at(i); |
|
|
|
GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(), |
|
|
|
first_node->GetName().c_str(), graph->GetName().c_str()); |
|
|
|
// the data node which can be fused has none input(both data and control in) |
|
|
|
if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (MoveOutDataEdges(node, first_node) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
// src_node will be replaced by dst_node |
|
|
|
Status FuseSameDataPass::MoveOutDataEdges(NodePtr &src_node, NodePtr &dst_node) { |
|
|
|
// key is node_name-in_index |
|
|
|
std::map<string, InDataAnchorPtr> src_out_node_to_indexs; |
|
|
|
GetOutDataNodeToIndexMap(src_node, src_out_node_to_indexs); |
|
|
|
if (src_out_node_to_indexs.empty()) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<string, InDataAnchorPtr> dst_out_node_to_indexs; |
|
|
|
GetOutDataNodeToIndexMap(dst_node, dst_out_node_to_indexs); |
|
|
|
|
|
|
|
auto dst_out_data_anchor = dst_node->GetOutDataAnchor(0); |
|
|
|
GE_CHECK_NOTNULL(dst_out_data_anchor); |
|
|
|
auto src_out_data_anchor = src_node->GetOutDataAnchor(0); |
|
|
|
GE_CHECK_NOTNULL(src_out_data_anchor); |
|
|
|
src_out_data_anchor->UnlinkAll(); |
|
|
|
for (auto it = src_out_node_to_indexs.begin(); it != src_out_node_to_indexs.end(); ++it) { |
|
|
|
if (dst_out_node_to_indexs.count(it->first) > 0) { |
|
|
|
continue; // exclusion of duplication |
|
|
|
} |
|
|
|
auto ret = dst_out_data_anchor->LinkTo(it->second); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Failed to move out data edge from %s to %s", src_node->GetName().c_str(), |
|
|
|
dst_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
void FuseSameDataPass::GetOutDataNodeToIndexMap(NodePtr &node, std::map<string, InDataAnchorPtr> &out_node_to_indexs) { |
|
|
|
auto out_data_anchor = node->GetOutDataAnchor(0); |
|
|
|
GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor); |
|
|
|
auto peer_in_anchors = out_data_anchor->GetPeerInDataAnchors(); |
|
|
|
if (!peer_in_anchors.empty()) { |
|
|
|
for (auto &anchor : peer_in_anchors) { |
|
|
|
int index = anchor->GetIdx(); |
|
|
|
NodePtr out_node = anchor->GetOwnerNode(); |
|
|
|
if (out_node == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
string key_name = out_node->GetName() + "-" + std::to_string(index); |
|
|
|
out_node_to_indexs[key_name] = anchor; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace ge |