|
|
|
@@ -0,0 +1,312 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "graph/passes/parallel_group_pass.h" |
|
|
|
|
|
|
|
#include "framework/common/debug/ge_log.h" |
|
|
|
#include "common/ge/ge_util.h" |
|
|
|
#include "framework/common/ge_inner_error_codes.h" |
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
#include "graph/utils/graph_utils.h" |
|
|
|
#include "graph/utils/node_utils.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
|
|
|
|
/// Entry point of the pass: processes the parallel-group nodes of a root graph and
/// re-sorts the graph topologically afterwards.
///
/// @param graph graph to optimize; a subgraph is accepted but left untouched.
/// @return PARAM_INVALID for a null graph, INTERNAL_ERROR when group processing fails,
///         GRAPH_FAILED when the topological sort fails, GRAPH_SUCCESS otherwise.
Status ParallelGroupPass::Run(ComputeGraphPtr graph) {
  GELOGI("ParallelGroupPass running");
  if (graph == nullptr) {
    GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group graph.");
    return PARAM_INVALID;
  }

  // Subgraphs are reached recursively through ProcessAllGraph on their root graph,
  // so a direct invocation on a subgraph is a no-op.
  const bool is_subgraph = (graph->GetParentGraph() != nullptr);
  if (is_subgraph) {
    GELOGW("Current graph %s is a subgraph, this pass only support root graph.",
           graph->GetName().c_str());
    return GRAPH_SUCCESS;
  }

  // The collected group ids are only needed by the recursion itself; they are discarded here.
  std::unordered_set<int> found_group_ids;
  const auto process_ret = ProcessAllGraph(graph, found_group_ids);
  if (process_ret != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Process graph %s failed.", graph->GetName().c_str());
    return INTERNAL_ERROR;
  }

  // Control edges may have been added, so the node order has to be recomputed.
  const auto sort_ret = graph->TopologicalSorting();
  if (sort_ret != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Graph topological sort failed.");
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}
|
|
|
|
|
|
|
/// Collects the parallel-group nodes of `graph` (recursing into the subgraphs of its nodes)
/// and chains the nodes of every group, in collection order, through control edges.
///
/// A node that owns a subgraph containing group nodes is itself treated as a member of
/// those groups in the current graph.
///
/// @param graph           graph to process.
/// @param parallel_group  output: receives every group id found in `graph` or its subgraphs.
/// @return PARAM_INVALID for a null graph, GRAPH_FAILED on any processing failure,
///         GRAPH_SUCCESS otherwise.
Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set<int> &parallel_group) {
  if (graph == nullptr) {
    GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group subgraph.");
    return PARAM_INVALID;
  }

  // group node -> (switch nodes guarding it, merge node closing the switch region)
  std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge;
  if (ProcessSwitch(graph, node_2_switch_merge) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Process switch of graph %s failed.", graph->GetName().c_str());
    return GRAPH_FAILED;
  }

  // group id -> nodes of that group, in the order GetDirectNode() yields them
  std::map<int, std::vector<NodePtr>> group_node;
  const auto &candidates = graph->GetDirectNode();
  for (const auto &node : candidates) {
    if (node == nullptr) {
      continue;
    }
    OpDescPtr op_desc = node->GetOpDesc();
    if (op_desc == nullptr) {
      continue;
    }

    int group_id = -1;
    if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) {
      bool is_unknown_shape = false;
      auto ret = ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape);
      if (ret != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "Get node[%s] shape status failed!", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      // only handle known shape node
      if (!is_unknown_shape) {
        group_node[group_id].push_back(node);
        parallel_group.insert(group_id);
        GELOGD("Find group node:%s, group_id=%d", node->GetName().c_str(), group_id);
      }
    }

    // Subgraphs are visited in reverse order of their instance names.
    const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
    for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) {
      const auto sub_graph = graph->GetSubgraph(*name_iter);
      if (sub_graph == nullptr) {
        continue;
      }
      std::unordered_set<int> sub_parallel_group;
      auto ret = ProcessAllGraph(sub_graph, sub_parallel_group);
      if (ret != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "Process sub graph %s failed.", sub_graph->GetName().c_str());
        return GRAPH_FAILED;
      }
      // The owning node stands in for the group nodes inside its subgraph.
      for (const auto &it : sub_parallel_group) {
        parallel_group.insert(it);
        group_node[it].emplace_back(node);
      }
    }
  }

  for (const auto &itr : group_node) {
    const auto &node_vec = itr.second;
    if (node_vec.empty()) {
      continue;
    }
    NodePtr pre_node = node_vec[0];
    for (size_t i = 1; i < node_vec.size(); ++i) {
      NodePtr cur_node = node_vec[i];
      // ReplaceSwitchAndMerge takes its nodes by non-const reference and may overwrite them,
      // so hand it copies and keep the originals for the next iteration.
      auto tmp_pre_node = pre_node;
      auto tmp_cur_node = cur_node;
      GELOGD("original add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
             cur_node->GetName().c_str());
      ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge);
      pre_node = cur_node;
    }
  }

  return GRAPH_SUCCESS;
}
|
|
|
|
|
|
|
/// Adds a control edge pre_node -> cur_node unless the two are the same node or
/// pre_node already is an input (data or control) of cur_node.
///
/// Null nodes, missing control anchors and a failing link are reported with a warning;
/// the function has no return value, so callers are not notified of these cases.
///
/// @param pre_node node that has to run first.
/// @param cur_node node that has to run after pre_node.
void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
  if ((pre_node == nullptr) || (cur_node == nullptr)) {
    GELOGW("Skip adding ctrl edge, pre_node or cur_node is nullptr.");
    return;
  }
  if (pre_node == cur_node) {
    return;
  }

  // Nothing to do when the ordering is already enforced by an existing edge.
  const auto &in_node = cur_node->GetInAllNodes();
  for (const auto &node : in_node) {
    if (pre_node == node) {
      GELOGD("node:%s and node:%s has linked", pre_node->GetName().c_str(),
             cur_node->GetName().c_str());
      return;
    }
  }

  const auto pre_out_ctrl_anchor = pre_node->GetOutControlAnchor();
  const auto cur_in_ctrl_anchor = cur_node->GetInControlAnchor();
  if ((pre_out_ctrl_anchor == nullptr) || (cur_in_ctrl_anchor == nullptr)) {
    GELOGW("Skip adding ctrl edge, ctrl anchor of node:%s or node:%s is nullptr.",
           pre_node->GetName().c_str(), cur_node->GetName().c_str());
    return;
  }
  if (pre_out_ctrl_anchor->LinkTo(cur_in_ctrl_anchor) != GRAPH_SUCCESS) {
    GELOGW("Link ctrl edge from node:%s to node:%s failed.", pre_node->GetName().c_str(),
           cur_node->GetName().c_str());
  }
}
|
|
|
|
|
|
|
/// For every Switch/RefSwitch node of `graph` that is not driven by a LoopCond, walks the
/// nodes reachable from it (stopping at Merge/RefMerge nodes) and records, for each reachable
/// node carrying ATTR_NAME_PARALLEL_GROUP, the switch nodes it was reached from and the merge
/// node that comes last in GetDirectNode() order.
///
/// @param graph                graph whose direct nodes are scanned.
/// @param node_2_switch_merge  output: group node -> (set of switch nodes, merge node).
/// @return PARAM_INVALID for a null graph, INTERNAL_ERROR for a broken switch condition input,
///         GRAPH_FAILED when a node type cannot be read or a group node is associated with two
///         different merge nodes, GRAPH_SUCCESS otherwise.
Status ParallelGroupPass::ProcessSwitch(ComputeGraphPtr graph,
                                        std::map<NodePtr,
                                                 std::pair<std::set<NodePtr>,
                                                           NodePtr>> &node_2_switch_merge) {
  if (graph == nullptr) {
    GELOGE(PARAM_INVALID, "Input param graph is null.");
    return PARAM_INVALID;
  }

  std::string type;
  const auto &direct_node = graph->GetDirectNode();

  // Position of every node in GetDirectNode() order, built once so that ordering merge nodes
  // does not need a linear search per comparison. Nodes that are not listed sort last.
  std::map<NodePtr, size_t> node_position;
  size_t next_position = 0;
  for (const auto &node : direct_node) {
    node_position.emplace(node, next_position);  // first occurrence wins
    ++next_position;
  }
  const size_t unknown_position = next_position;
  const auto position_of = [&node_position, unknown_position](const NodePtr &target) -> size_t {
    const auto pos_itr = node_position.find(target);
    return (pos_itr == node_position.end()) ? unknown_position : pos_itr->second;
  };

  for (const auto &node : direct_node) {
    auto ret = GetOriginalType(node, type);
    if (ret != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str());
      return GRAPH_FAILED;
    }
    if ((type != SWITCH) && (type != REFSWITCH)) {
      continue;
    }

    InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
    if (in_cond_anchor == nullptr) {
      GELOGE(INTERNAL_ERROR, "in_cond_anchor is nullptr, node: %s.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }
    OutDataAnchorPtr pred_cond_anchor = in_cond_anchor->GetPeerOutAnchor();
    if (pred_cond_anchor == nullptr) {
      GELOGE(INTERNAL_ERROR, "pred_cond_anchor is nullptr, node: %s.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }
    const auto pred_cond_node = pred_cond_anchor->GetOwnerNode();
    if (pred_cond_node == nullptr) {
      GELOGE(INTERNAL_ERROR, "owner node of pred_cond_anchor is nullptr, node: %s.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }
    // ignore while op
    if (pred_cond_node->GetType() == LOOPCOND) {
      continue;
    }

    // Breadth-first walk over the outputs of the switch; merge nodes end a path.
    std::deque<NodePtr> candidates;
    std::vector<NodePtr> merge_vec;
    std::set<NodePtr> group_set;
    std::set<NodePtr> visited;
    candidates.emplace_back(node);
    while (!candidates.empty()) {
      NodePtr tmp_node = candidates.front();
      candidates.pop_front();
      for (const auto &out_node : tmp_node->GetOutAllNodes()) {
        ret = GetOriginalType(out_node, type);
        if (ret != GRAPH_SUCCESS) {
          GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str());
          return GRAPH_FAILED;
        }
        if ((type == MERGE) || (type == REFMERGE)) {
          merge_vec.emplace_back(out_node);
          continue;
        }
        const auto op = out_node->GetOpDesc();
        if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
          group_set.emplace(out_node);
        }
        if (visited.count(out_node) > 0) {
          continue;
        }
        candidates.push_back(out_node);
        visited.insert(out_node);
      }
    }

    if (group_set.empty()) {
      continue;
    }
    // Without a merge node there is nothing to pair the group nodes with; calling back() on the
    // empty vector would be undefined behaviour, so this switch is left unrecorded.
    if (merge_vec.empty()) {
      GELOGW("No merge node found behind switch node %s, skip its parallel group nodes.",
             node->GetName().c_str());
      continue;
    }

    std::sort(merge_vec.begin(), merge_vec.end(),
              [&position_of](const NodePtr &a, const NodePtr &b) -> bool {
                return position_of(a) < position_of(b);
              });
    const NodePtr &last_merge = merge_vec.back();

    for (const auto &group_node : group_set) {
      auto it = node_2_switch_merge.find(group_node);
      if (it != node_2_switch_merge.end()) {
        auto &tmp = it->second;
        auto &switch_vec = tmp.first;
        const auto &merge_node = tmp.second;
        GELOGD("Find group node: %s in switch node %s and merge node :%s.",
               group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str());
        // A group node reached from several switches must end at the same merge node.
        if (merge_node != last_merge) {
          GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.",
                 merge_node->GetName().c_str(), last_merge->GetName().c_str());
          return GRAPH_FAILED;
        }
        switch_vec.insert(node);
      } else {
        node_2_switch_merge.emplace(group_node, std::make_pair(std::set<NodePtr>{node}, last_merge));
      }
    }
  }

  return GRAPH_SUCCESS;
}
|
|
|
|
|
|
|
/// Reads the type of `node`; for a FRAMEWORKOP node the type stored in
/// ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE is returned instead.
///
/// @param node node to inspect.
/// @param type output: the resolved type. It is overwritten with node->GetType() before the
///             FRAMEWORKOP attribute is read, so it may be modified even on failure.
/// @return GRAPH_FAILED when the node or its op desc is null, INTERNAL_ERROR when the original
///         type attribute cannot be read, SUCCESS otherwise.
Status ParallelGroupPass::GetOriginalType(const NodePtr &node, string &type) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is nullptr.");
    return GRAPH_FAILED;
  }

  type = node->GetType();
  if (type != FRAMEWORKOP) {
    return SUCCESS;
  }

  const auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "op_desc is nullptr.");
    return GRAPH_FAILED;
  }

  if (!ge::AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) {
    GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type failed, node: %s.", node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
  return SUCCESS;
}
|
|
|
|
|
|
|
/// Adds the control edge(s) that order `pre_node` before `cur_node`, substituting the
/// recorded switch/merge nodes when either node has an entry in `node_2_switch_merge`:
///  - neither recorded: edge pre_node -> cur_node;
///  - only pre_node recorded: edge (merge of pre_node) -> cur_node;
///  - both recorded and sharing no switch: edges (merge of pre_node) -> every switch of cur_node;
///  - both recorded and sharing a switch: nothing is added;
///  - only cur_node recorded: per switch of cur_node, either pre_node -> switch, or
///    (merge of cur_node) -> pre_node when pre_node's op id is greater than the switch's.
///
/// @param pre_node            earlier group node; overwritten with its merge node in the
///                            cases where that merge node is used.
/// @param cur_node            later group node.
/// @param node_2_switch_merge group node -> (switch nodes, merge node) lookup.
void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node,
                                              NodePtr &cur_node,
                                              std::map<NodePtr,
                                                       std::pair<std::set<NodePtr>,
                                                                 NodePtr>> &node_2_switch_merge) {
  const auto pre_entry = node_2_switch_merge.find(pre_node);
  const auto cur_entry = node_2_switch_merge.find(cur_node);
  const bool pre_recorded = (pre_entry != node_2_switch_merge.end());
  const bool cur_recorded = (cur_entry != node_2_switch_merge.end());

  // Case 1: plain nodes on both sides.
  if (!pre_recorded && !cur_recorded) {
    AddCtrlEdge(pre_node, cur_node);
    GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
           cur_node->GetName().c_str());
    return;
  }

  // Case 2: only the earlier node sits behind a switch; continue from its merge node.
  if (pre_recorded && !cur_recorded) {
    pre_node = pre_entry->second.second;
    AddCtrlEdge(pre_node, cur_node);
    GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
           cur_node->GetName().c_str());
    return;
  }

  const auto &cur_switches = cur_entry->second.first;

  // Case 3: both nodes sit behind switches.
  if (pre_recorded) {
    if (HasSameSwitch(pre_entry->second.first, cur_switches)) {
      return;
    }
    pre_node = pre_entry->second.second;
    for (const auto &target_switch : cur_switches) {
      AddCtrlEdge(pre_node, target_switch);
      GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
             target_switch->GetName().c_str());
    }
    return;
  }

  // Case 4: only the later node sits behind a switch.
  for (const auto &target_switch : cur_switches) {
    const int64_t id_of_pre = pre_node->GetOpDesc()->GetId();
    const int64_t id_of_switch = target_switch->GetOpDesc()->GetId();
    if (id_of_pre <= id_of_switch) {
      AddCtrlEdge(pre_node, target_switch);
      GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(),
             target_switch->GetName().c_str());
      continue;
    }
    // special handle for merge and group node
    const auto closing_merge = cur_entry->second.second;
    AddCtrlEdge(closing_merge, pre_node);
    GELOGD("finally add ctrl anchor for node:%s-->node:%s", closing_merge->GetName().c_str(),
           pre_node->GetName().c_str());
  }
}
|
|
|
|
|
|
|
bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, |
|
|
|
const std::set<NodePtr> &switch_set2) { |
|
|
|
for (const auto &node1 : switch_set1) { |
|
|
|
for (const auto &node2 : switch_set2) { |
|
|
|
if (node1 == node2) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace ge |