|
|
|
@@ -0,0 +1,149 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "hccl_group_switch_copy_pass.h" |
|
|
|
#include <stack> |
|
|
|
#include "framework/common/debug/ge_log.h" |
|
|
|
#include "framework/common/util.h" |
|
|
|
#include "graph/common/omg_util.h" |
|
|
|
#include "graph/debug/ge_attr_define.h" |
|
|
|
#include "graph/utils/graph_utils.h" |
|
|
|
#include "common/op/ge_op_utils.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
static const std::set<std::string> kEndTypes = { SWITCH, REFSWITCH, MERGE, REFMERGE, ENTER, REFENTER }; |
|
|
|
uint32_t HcclGroupSwitchCopyPass::copy_num_ = 0; |
|
|
|
|
|
|
|
Status HcclGroupSwitchCopyPass::Run(NodePtr &node) { |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
std::string type; |
|
|
|
GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.") |
|
|
|
if ((type != SWITCH) && (type != REFSWITCH)) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
OpDescPtr op_desc = node->GetOpDesc(); |
|
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
|
if (op_desc->HasAttr(ATTR_NAME_HCCL_FUSED_GROUP)) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
bool graph_change = false; |
|
|
|
std::string group_name; |
|
|
|
for (const auto &item : node->GetOutDataNodesAndAnchors()) { |
|
|
|
const auto &out_node = item.first; |
|
|
|
if (!IsHcclGroupMarked(out_node, group_name)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (node->GetOutDataNodesSize() == 1) { |
|
|
|
if (!AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_name)) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Set attr ATTR_NAME_HCCL_FUSED_GROUP failed."); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
if (CopySwitchNode(node, out_node, item.second, group_name) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "copy switch node for cur_switch=%s, out_node=%s failed", |
|
|
|
node->GetName().c_str(), out_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
AddRePassNodesWithInOut(const_cast<NodePtr &>(out_node)); |
|
|
|
graph_change = true; |
|
|
|
} |
|
|
|
if (graph_change) { |
|
|
|
AddRePassNode(node); |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
bool HcclGroupSwitchCopyPass::IsHcclGroupMarked(const NodePtr &branch_head_node, std::string &group_name) { |
|
|
|
group_name.clear(); |
|
|
|
std::stack<NodePtr> nodes; |
|
|
|
nodes.push(branch_head_node); |
|
|
|
std::set<NodePtr> visited_nodes; |
|
|
|
while (!nodes.empty()) { |
|
|
|
const NodePtr &cur_node = nodes.top(); |
|
|
|
nodes.pop(); |
|
|
|
if (visited_nodes.count(cur_node) > 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
std::string type; |
|
|
|
GE_CHK_STATUS_RET(GetOriginalType(cur_node, type), "Get node type failed.") |
|
|
|
if (kEndTypes.count(type) > 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
(void)AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_name); |
|
|
|
if (!group_name.empty()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
for (const auto &out_node : cur_node->GetOutAllNodes()) { |
|
|
|
nodes.push(out_node); |
|
|
|
} |
|
|
|
visited_nodes.insert(cur_node); |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
Status HcclGroupSwitchCopyPass::CopySwitchNode(const NodePtr &switch_node, |
|
|
|
const NodePtr &branch_head_node, |
|
|
|
const InDataAnchorPtr &peer_in_anchor, |
|
|
|
const std::string &group_name) { |
|
|
|
const auto ©_op_desc = AttrUtils::CopyOpDesc(switch_node->GetOpDesc()); |
|
|
|
GE_CHECK_NOTNULL(copy_op_desc); |
|
|
|
copy_op_desc->SetName("CopySwitch_" + std::to_string(copy_num_++)); |
|
|
|
if (!AttrUtils::SetStr(copy_op_desc, ATTR_NAME_HCCL_FUSED_GROUP, group_name)) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Set attr ATTR_NAME_HCCL_FUSED_GROUP failed."); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
const auto ©_node = switch_node->GetOwnerComputeGraph()->AddNode(copy_op_desc); |
|
|
|
GE_CHECK_NOTNULL(copy_node); |
|
|
|
|
|
|
|
const auto &out_data_anchor = peer_in_anchor->GetPeerOutAnchor(); |
|
|
|
GE_CHECK_NOTNULL(out_data_anchor); |
|
|
|
if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) || |
|
|
|
(GraphUtils::AddEdge(copy_node->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor) != GRAPH_SUCCESS)) { |
|
|
|
GELOGE(FAILED, "Replace data edge %s:%d->%s:%d by %s:%d->%s:%d failed.", |
|
|
|
switch_node->GetName().c_str(), out_data_anchor->GetIdx(), |
|
|
|
branch_head_node->GetName().c_str(), peer_in_anchor->GetIdx(), |
|
|
|
copy_node->GetName().c_str(), out_data_anchor->GetIdx(), |
|
|
|
branch_head_node->GetName().c_str(), peer_in_anchor->GetIdx()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
for (const auto &in_data_anchor : switch_node->GetAllInDataAnchors()) { |
|
|
|
const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); |
|
|
|
if (peer_out_anchor == nullptr) { continue; } |
|
|
|
if (GraphUtils::AddEdge(peer_out_anchor, copy_node->GetInDataAnchor(in_data_anchor->GetIdx())) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), |
|
|
|
peer_out_anchor->GetIdx(), copy_node->GetName().c_str(), in_data_anchor->GetIdx()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
for (const auto &in_node : switch_node->GetInControlNodes()) { |
|
|
|
if (GraphUtils::AddEdge(in_node->GetOutControlAnchor(), copy_node->GetInControlAnchor()) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Add ctrl edge %s->%s failed.", in_node->GetName().c_str(), copy_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
for (const auto &out_node : switch_node->GetOutControlNodes()) { |
|
|
|
if (GraphUtils::AddEdge(copy_node->GetOutControlAnchor(), out_node->GetInControlAnchor()) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Add ctrl edge %s->%s failed.", copy_node->GetName().c_str(), out_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} // namespace ge |