From 6298acb98a66cc6c1c43dc7237aec5b249997ff4 Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Fri, 15 Jan 2021 17:02:19 +0800 Subject: [PATCH] add HcclGroupSwitchCopyPass --- ge/CMakeLists.txt | 2 + ge/ge_inference.mk | 1 + ge/ge_runner.mk | 1 + .../passes/hccl_group_switch_copy_pass.cc | 149 ++++++++++++++++++ ge/graph/passes/hccl_group_switch_copy_pass.h | 36 +++++ ge/graph/preprocess/graph_preprocess.cc | 3 + tests/ut/ge/CMakeLists.txt | 1 + 7 files changed, 193 insertions(+) create mode 100644 ge/graph/passes/hccl_group_switch_copy_pass.cc create mode 100644 ge/graph/passes/hccl_group_switch_copy_pass.h diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 88d74730..79441d20 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -164,6 +164,7 @@ set(TRAIN_SRC_LIST "graph/passes/dimension_compute_pass.cc" "graph/passes/dropout_pass.cc" "graph/passes/hccl_group_pass.cc" + "graph/passes/hccl_group_switch_copy_pass.cc" "graph/passes/enter_pass.cc" "graph/passes/assign_pass.cc" "graph/passes/flow_ctrl_pass.cc" @@ -542,6 +543,7 @@ set(INFER_SRC_LIST "graph/passes/link_gen_mask_nodes_pass.cc" "graph/passes/replace_with_empty_const_pass.cc" "graph/passes/hccl_group_pass.cc" + "graph/passes/hccl_group_switch_copy_pass.cc" "graph/passes/memcpy_addr_async_pass.cc" "graph/passes/set_input_output_offset_pass.cc" "graph/manager/model_manager/event_manager.cc" diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk index e20456d5..fa590249 100755 --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -214,6 +214,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/link_gen_mask_nodes_pass.cc \ graph/passes/replace_with_empty_const_pass.cc \ graph/passes/hccl_group_pass.cc \ + graph/passes/hccl_group_switch_copy_pass.cc \ graph/passes/memcpy_addr_async_pass.cc \ graph/passes/set_input_output_offset_pass.cc \ diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index 9706dadb..a9b5b12a 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -133,6 +133,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/dimension_compute_pass.cc \ graph/passes/dropout_pass.cc \ graph/passes/hccl_group_pass.cc \ + graph/passes/hccl_group_switch_copy_pass.cc \ graph/passes/enter_pass.cc \ graph/passes/assign_pass.cc \ graph/passes/flow_ctrl_pass.cc \ diff --git a/ge/graph/passes/hccl_group_switch_copy_pass.cc b/ge/graph/passes/hccl_group_switch_copy_pass.cc new file mode 100644 index 00000000..b501a57b --- /dev/null +++ b/ge/graph/passes/hccl_group_switch_copy_pass.cc @@ -0,0 +1,149 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hccl_group_switch_copy_pass.h" +#include +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "common/op/ge_op_utils.h" + +namespace ge { +static const std::set kEndTypes = { SWITCH, REFSWITCH, MERGE, REFMERGE, ENTER, REFENTER }; +uint32_t HcclGroupSwitchCopyPass::copy_num_ = 0; + +Status HcclGroupSwitchCopyPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + std::string type; + GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.") + if ((type != SWITCH) && (type != REFSWITCH)) { + return SUCCESS; + } + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->HasAttr(ATTR_NAME_HCCL_FUSED_GROUP)) { + return SUCCESS; + } + bool graph_change = false; + std::string group_name; + for (const auto &item : node->GetOutDataNodesAndAnchors()) { + const auto &out_node = item.first; + if (!IsHcclGroupMarked(out_node, group_name)) { + continue; + } + if (node->GetOutDataNodesSize() == 1) { + if (!AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_name)) { + GELOGE(INTERNAL_ERROR, "Set attr ATTR_NAME_HCCL_FUSED_GROUP failed."); + return INTERNAL_ERROR; + } + return SUCCESS; + } + if (CopySwitchNode(node, out_node, item.second, group_name) != SUCCESS) { + GELOGE(FAILED, "copy switch node for cur_switch=%s, out_node=%s failed", + node->GetName().c_str(), out_node->GetName().c_str()); + return FAILED; + } + AddRePassNodesWithInOut(const_cast(out_node)); + graph_change = true; + } + if (graph_change) { + AddRePassNode(node); + } + + return SUCCESS; +} + +bool HcclGroupSwitchCopyPass::IsHcclGroupMarked(const NodePtr &branch_head_node, std::string &group_name) { + group_name.clear(); + std::stack nodes; + nodes.push(branch_head_node); + std::set visited_nodes; + while (!nodes.empty()) { + const NodePtr &cur_node = nodes.top(); + nodes.pop(); + if (visited_nodes.count(cur_node) > 0) { + continue; + } + + std::string type; + GE_CHK_STATUS_RET(GetOriginalType(cur_node, type), "Get node type failed.") + if (kEndTypes.count(type) > 0) { + continue; + } + (void)AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_name); + if (!group_name.empty()) { + return true; + } + for (const auto &out_node : cur_node->GetOutAllNodes()) { + nodes.push(out_node); + } + visited_nodes.insert(cur_node); + } + return false; +} + +Status HcclGroupSwitchCopyPass::CopySwitchNode(const NodePtr &switch_node, + const NodePtr &branch_head_node, + const InDataAnchorPtr &peer_in_anchor, + const std::string &group_name) { + const auto ©_op_desc = AttrUtils::CopyOpDesc(switch_node->GetOpDesc()); + GE_CHECK_NOTNULL(copy_op_desc); + copy_op_desc->SetName("CopySwitch_" + std::to_string(copy_num_++)); + if (!AttrUtils::SetStr(copy_op_desc, ATTR_NAME_HCCL_FUSED_GROUP, group_name)) { + GELOGE(INTERNAL_ERROR, "Set attr ATTR_NAME_HCCL_FUSED_GROUP failed."); + return INTERNAL_ERROR; + } + const auto ©_node = switch_node->GetOwnerComputeGraph()->AddNode(copy_op_desc); + GE_CHECK_NOTNULL(copy_node); + + const auto &out_data_anchor = peer_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(out_data_anchor); + if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) || + (GraphUtils::AddEdge(copy_node->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "Replace data edge %s:%d->%s:%d by %s:%d->%s:%d failed.", + switch_node->GetName().c_str(), out_data_anchor->GetIdx(), + branch_head_node->GetName().c_str(), peer_in_anchor->GetIdx(), + copy_node->GetName().c_str(), out_data_anchor->GetIdx(), + branch_head_node->GetName().c_str(), peer_in_anchor->GetIdx()); + return FAILED; + } + for (const auto &in_data_anchor : switch_node->GetAllInDataAnchors()) { + const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { continue; } + if (GraphUtils::AddEdge(peer_out_anchor, copy_node->GetInDataAnchor(in_data_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), + peer_out_anchor->GetIdx(), copy_node->GetName().c_str(), in_data_anchor->GetIdx()); + return FAILED; + } + } + for (const auto &in_node : switch_node->GetInControlNodes()) { + if (GraphUtils::AddEdge(in_node->GetOutControlAnchor(), copy_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add ctrl edge %s->%s failed.", in_node->GetName().c_str(), copy_node->GetName().c_str()); + return FAILED; + } + } + for (const auto &out_node : switch_node->GetOutControlNodes()) { + if (GraphUtils::AddEdge(copy_node->GetOutControlAnchor(), out_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add ctrl edge %s->%s failed.", copy_node->GetName().c_str(), out_node->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} +} // namespace ge diff --git a/ge/graph/passes/hccl_group_switch_copy_pass.h b/ge/graph/passes/hccl_group_switch_copy_pass.h new file mode 100644 index 00000000..27a21605 --- /dev/null +++ b/ge/graph/passes/hccl_group_switch_copy_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_HCCL_GROUP_SWITCH_COPY_PASS_H_ +#define GE_GRAPH_PASSES_HCCL_GROUP_SWITCH_COPY_PASS_H_ + +#include "graph/passes/base_pass.h" +namespace ge { +class HcclGroupSwitchCopyPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + private: + static bool IsHcclGroupMarked(const NodePtr &branch_head_node, std::string &group_name); + static Status CopySwitchNode(const NodePtr &switch_node, + const NodePtr &branch_head_node, + const InDataAnchorPtr &peer_in_anchor, + const std::string &group_name); + + static uint32_t copy_num_; +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_HCCL_GROUP_SWITCH_COPY_PASS_H_ diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 27c54f03..0fcaa53f 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -60,6 +60,7 @@ #include "graph/passes/get_original_format_pass.h" #include "graph/passes/guarantee_const_pass.h" #include "graph/passes/hccl_group_pass.h" +#include "graph/passes/hccl_group_switch_copy_pass.h" #include "graph/passes/hccl_memcpy_pass.h" #include "graph/passes/identity_pass.h" #include "graph/passes/infershape_pass.h" @@ -1453,8 +1454,10 @@ Status GraphPrepare::SwitchOpOptimize(ComputeGraphPtr &compute_graph) { GEPass ge_passes(compute_graph); NamesToPass hccl_group; HcclGroupPass hccl_group_pass; + HcclGroupSwitchCopyPass hccl_group_switch_copy_pass; GELOGD("Add hccl group pass success"); hccl_group.emplace_back("HcclGroupPass", &hccl_group_pass); + hccl_group.emplace_back("HcclGroupSwitchCopyPass", &hccl_group_switch_copy_pass); auto ret = ge_passes.Run(hccl_group); if (ret != SUCCESS) { GELOGE(ret, "Run HcclGroupPass pass for preprocess failed, ret:%u.", ret); diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index e4b8d8d2..31eaffa1 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -243,6 +243,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/link_gen_mask_nodes_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/replace_with_empty_const_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/hccl_group_switch_copy_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"