From ad69e1af151ed608c6278b8d63d5f6cbfc7b2159 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=8D=8E?= Date: Fri, 26 Feb 2021 10:34:14 +0800 Subject: [PATCH] parallel group --- ge/graph/passes/parallel_group_pass.cc | 312 +++++++++++++++++++++++++ ge/graph/passes/parallel_group_pass.h | 47 ++++ 2 files changed, 359 insertions(+) create mode 100644 ge/graph/passes/parallel_group_pass.cc create mode 100644 ge/graph/passes/parallel_group_pass.h diff --git a/ge/graph/passes/parallel_group_pass.cc b/ge/graph/passes/parallel_group_pass.cc new file mode 100644 index 00000000..b682f17a --- /dev/null +++ b/ge/graph/passes/parallel_group_pass.cc @@ -0,0 +1,312 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/parallel_group_pass.h" + +#include "framework/common/debug/ge_log.h" +#include "common/ge/ge_util.h" +#include "framework/common/ge_inner_error_codes.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" + +namespace ge { + +Status ParallelGroupPass::Run(ComputeGraphPtr graph) { + GELOGI("ParallelGroupPass running"); + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group graph."); + return PARAM_INVALID; + } + + if (graph->GetParentGraph() != nullptr) { + GELOGW("Current graph %s is a subgraph, this pass only support root graph.", + graph->GetName().c_str()); + return GRAPH_SUCCESS; + } + + std::unordered_set parallel_group; + if (ProcessAllGraph(graph, parallel_group) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Process graph %s failed.", graph->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (graph->TopologicalSorting() != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Graph topological sort failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +Status ParallelGroupPass::ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set ¶llel_group) { + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "Input param graph is null, skip optimize parallel group subgraph."); + return PARAM_INVALID; + } + + std::map, NodePtr>> node_2_switch_merge; + if (ProcessSwitch(graph, node_2_switch_merge) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Input param graph is null, skip optimize parallel group subgraph."); + return GRAPH_FAILED; + } + + std::map> group_node; + const auto &candidates = graph->GetDirectNode(); + for (const auto &node : candidates) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + int group_id = -1; + if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARALLEL_GROUP, group_id)) { + bool is_unknown_shape = false; + auto ret = ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get node[%s] shape status failed!", node->GetName().c_str()); + return GRAPH_FAILED; + } + // only handle known shape node + if (!is_unknown_shape) { + group_node[group_id].push_back(node); + parallel_group.insert(group_id); + GELOGD("Find group node:%s, group_id=%d", node->GetName().c_str(), group_id); + } + } + + const auto &subgraph_name = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) { + const auto sub_graph = graph->GetSubgraph(*name_iter); + if (sub_graph != nullptr) { + std::unordered_set sub_parallel_group; + auto ret = ProcessAllGraph(sub_graph, sub_parallel_group); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Process sub graph %s failed.", sub_graph->GetName().c_str()); + return GRAPH_FAILED; + } + for (const auto &it : sub_parallel_group) { + parallel_group.insert(it); + group_node[it].emplace_back(node); + } + } + } + } + + for (const auto &itr : group_node) { + const auto &node_vec = itr.second; + if (node_vec.empty()) { + continue; + } + NodePtr pre_node = node_vec[0]; + NodePtr cur_node = nullptr; + for (int i = 1; i < node_vec.size(); i++) { + cur_node = node_vec[i]; + auto tmp_pre_node = pre_node; + auto tmp_cur_node = cur_node; + GELOGD("original add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), + cur_node->GetName().c_str()); + ReplaceSwitchAndMerge(tmp_pre_node, tmp_cur_node, node_2_switch_merge); + pre_node = cur_node; + } + } + + return GRAPH_SUCCESS; +} + +void ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) { + if (pre_node == cur_node) { + return; + } + const auto &in_node = cur_node->GetInAllNodes(); + for (const auto &node : in_node) { + if (pre_node == node) { + GELOGD("node:%s and node:%s has linked", pre_node->GetName().c_str(), + cur_node->GetName().c_str()); + return; + } + } + auto pre_out_ctrl_anchor = pre_node->GetOutControlAnchor(); + auto cur_in_ctrl_anchor = cur_node->GetInControlAnchor(); + pre_out_ctrl_anchor->LinkTo(cur_in_ctrl_anchor); +} + +Status ParallelGroupPass::ProcessSwitch(ComputeGraphPtr graph, + std::map, + NodePtr>> &node_2_switch_merge) { + std::string type; + const auto &direct_node = graph->GetDirectNode(); + for (const auto &node : direct_node) { + auto ret = GetOriginalType(node, type); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if ((type != SWITCH) && (type != REFSWITCH)) { + continue; + } + InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + if (in_cond_anchor == nullptr) { + GELOGE(INTERNAL_ERROR, "in_cond_anchor is nullptr, node: %s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + OutDataAnchorPtr pred_cond_anchor = in_cond_anchor->GetPeerOutAnchor(); + if (pred_cond_anchor == nullptr) { + GELOGE(INTERNAL_ERROR, "pred_cond_anchor is nullptr, node: %s.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + // ignore while op + if (pred_cond_anchor->GetOwnerNode()->GetType() == LOOPCOND) { + continue; + } + std::deque candidates; + std::vector merge_vec; + std::set group_set; + std::set visited; + candidates.emplace_back(node); + while (!candidates.empty()) { + NodePtr tmp_node = candidates.front(); + candidates.pop_front(); + for (const auto &out_node : tmp_node->GetOutAllNodes()) { + ret = GetOriginalType(out_node, type); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get node %s type failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if ((type == MERGE) || (type == REFMERGE)) { + merge_vec.emplace_back(out_node); + continue; + } + const auto op = out_node->GetOpDesc(); + if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) { + group_set.emplace(out_node); + } + if (visited.count(out_node) > 0) { + continue; + } + candidates.push_back(out_node); + visited.insert(out_node); + } + } + std::sort(merge_vec.begin(), merge_vec.end(), + [&direct_node] (NodePtr a, NodePtr b) -> bool { + return std::find(direct_node.begin(), direct_node.end(), a) < + std::find(direct_node.begin(), direct_node.end(), b); + }); + for (const auto &group_node : group_set) { + auto it = node_2_switch_merge.find(group_node); + if (it != node_2_switch_merge.end()) { + auto &tmp = it->second; + auto &switch_vec = tmp.first; + const auto &merge_node = tmp.second; + GELOGD("Find group node: %s in switch node %s and merge node :%s.", + group_node->GetName().c_str(), node->GetName().c_str(), merge_node->GetName().c_str()); + if (merge_node != merge_vec.back()) { + GELOGE(GRAPH_FAILED, "error: has two merge node: %s and %s.", + merge_node->GetName().c_str(), merge_vec.back()->GetName().c_str()); + return GRAPH_FAILED; + } + switch_vec.insert(node); + } else { + node_2_switch_merge.emplace(group_node, std::make_pair(std::set{node}, merge_vec.back())); + } + } + } + + return GRAPH_SUCCESS; +} + +Status ParallelGroupPass::GetOriginalType(const NodePtr &node, string &type) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is nullptr."); + return GRAPH_FAILED; + } + type = node->GetType(); + GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS); + if (node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "op_desc is nullptr."); + return GRAPH_FAILED; + } + bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); + if (!ret) { + GELOGE(INTERNAL_ERROR, "node is nullptr."); + return INTERNAL_ERROR; + } + GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); + return SUCCESS; +} + +void ParallelGroupPass::ReplaceSwitchAndMerge(NodePtr &pre_node, + NodePtr &cur_node, + std::map, + NodePtr>> &node_2_switch_merge) { + auto pre_itr = node_2_switch_merge.find(pre_node); + auto cur_itr = node_2_switch_merge.find(cur_node); + if (pre_itr != node_2_switch_merge.end()) { + if (cur_itr != node_2_switch_merge.end()) { + const auto &pre_set = pre_itr->second.first; + const auto &cur_set = cur_itr->second.first; + if (!HasSameSwitch(pre_set, cur_set)) { + pre_node = pre_itr->second.second; + for (const auto &switch_node : cur_itr->second.first) { + AddCtrlEdge(pre_node, switch_node); + GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), + switch_node->GetName().c_str()); + } + } + } else { + pre_node = pre_itr->second.second; + AddCtrlEdge(pre_node, cur_node); + GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), + cur_node->GetName().c_str()); + } + } else { + if (cur_itr != node_2_switch_merge.end()) { + for (const auto &switch_node : cur_itr->second.first) { + int64_t pre_id = pre_node->GetOpDesc()->GetId(); + int64_t switch_id = switch_node->GetOpDesc()->GetId(); + if (pre_id > switch_id) { // special handle for merge and group node + auto merge_node = cur_itr->second.second; + AddCtrlEdge(merge_node, pre_node); + GELOGD("finally add ctrl anchor for node:%s-->node:%s", merge_node->GetName().c_str(), + pre_node->GetName().c_str()); + } else { + AddCtrlEdge(pre_node, switch_node); + GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), + switch_node->GetName().c_str()); + } + } + } else { + AddCtrlEdge(pre_node, cur_node); + GELOGD("finally add ctrl anchor for node:%s-->node:%s", pre_node->GetName().c_str(), + cur_node->GetName().c_str()); + } + } +} + +bool ParallelGroupPass::HasSameSwitch(const std::set &switch_set1, + const std::set &switch_set2) { + for (const auto &node1 : switch_set1) { + for (const auto &node2 : switch_set2) { + if (node1 == node2) { + return true; + } + } + } + return false; +} +} // namespace ge diff --git a/ge/graph/passes/parallel_group_pass.h b/ge/graph/passes/parallel_group_pass.h new file mode 100644 index 00000000..52ff0307 --- /dev/null +++ b/ge/graph/passes/parallel_group_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H +#define GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H + + +#include +#include +#include +#include "graph/graph.h" +#include "inc/graph_pass.h" + +namespace ge { +class ParallelGroupPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + private: + Status ProcessAllGraph(ComputeGraphPtr graph, std::unordered_set ¶llel_group); + void AddCtrlEdge(NodePtr pre_node, NodePtr cur_node); + Status GetOriginalType(const ge::NodePtr &node, std::string &type); + void ReplaceSwitchAndMerge(NodePtr &pre_node, + NodePtr &cur_node, + std::map, + NodePtr>> &node_2_switch_merge); + bool HasSameSwitch(const std::set &a, const std::set &b); + Status ProcessSwitch(ComputeGraphPtr graph, + std::map, + NodePtr>> &node_2_switch_merge); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H