Browse Source

Pre Merge pull request !954 from 陈叶朦/r1.2.0

pull/954/MERGE
陈叶朦 Gitee 5 years ago
parent
commit
d80e09f61b
7 changed files with 193 additions and 0 deletions
  1. +2
    -0
      ge/CMakeLists.txt
  2. +1
    -0
      ge/ge_inference.mk
  3. +1
    -0
      ge/ge_runner.mk
  4. +149
    -0
      ge/graph/passes/hccl_group_switch_copy_pass.cc
  5. +36
    -0
      ge/graph/passes/hccl_group_switch_copy_pass.h
  6. +3
    -0
      ge/graph/preprocess/graph_preprocess.cc
  7. +1
    -0
      tests/ut/ge/CMakeLists.txt

+ 2
- 0
ge/CMakeLists.txt View File

@@ -164,6 +164,7 @@ set(TRAIN_SRC_LIST
"graph/passes/dimension_compute_pass.cc"
"graph/passes/dropout_pass.cc"
"graph/passes/hccl_group_pass.cc"
"graph/passes/hccl_group_switch_copy_pass.cc"
"graph/passes/enter_pass.cc"
"graph/passes/assign_pass.cc"
"graph/passes/flow_ctrl_pass.cc"
@@ -542,6 +543,7 @@ set(INFER_SRC_LIST
"graph/passes/link_gen_mask_nodes_pass.cc"
"graph/passes/replace_with_empty_const_pass.cc"
"graph/passes/hccl_group_pass.cc"
"graph/passes/hccl_group_switch_copy_pass.cc"
"graph/passes/memcpy_addr_async_pass.cc"
"graph/passes/set_input_output_offset_pass.cc"
"graph/manager/model_manager/event_manager.cc"


+ 1
- 0
ge/ge_inference.mk View File

@@ -214,6 +214,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/link_gen_mask_nodes_pass.cc \
graph/passes/replace_with_empty_const_pass.cc \
graph/passes/hccl_group_pass.cc \
graph/passes/hccl_group_switch_copy_pass.cc \
graph/passes/memcpy_addr_async_pass.cc \
graph/passes/set_input_output_offset_pass.cc \



+ 1
- 0
ge/ge_runner.mk View File

@@ -133,6 +133,7 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/dimension_compute_pass.cc \
graph/passes/dropout_pass.cc \
graph/passes/hccl_group_pass.cc \
graph/passes/hccl_group_switch_copy_pass.cc \
graph/passes/enter_pass.cc \
graph/passes/assign_pass.cc \
graph/passes/flow_ctrl_pass.cc \


+ 149
- 0
ge/graph/passes/hccl_group_switch_copy_pass.cc View File

@@ -0,0 +1,149 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "hccl_group_switch_copy_pass.h"
#include <stack>
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "common/op/ge_op_utils.h"

namespace ge {
static const std::set<std::string> kEndTypes = { SWITCH, REFSWITCH, MERGE, REFMERGE, ENTER, REFENTER };
uint32_t HcclGroupSwitchCopyPass::copy_num_ = 0;

Status HcclGroupSwitchCopyPass::Run(NodePtr &node) {
GE_CHECK_NOTNULL(node);
std::string type;
GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.")
if ((type != SWITCH) && (type != REFSWITCH)) {
return SUCCESS;
}
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (op_desc->HasAttr(ATTR_NAME_HCCL_FUSED_GROUP)) {
return SUCCESS;
}
bool graph_change = false;
std::string group_name;
for (const auto &item : node->GetOutDataNodesAndAnchors()) {
const auto &out_node = item.first;
if (!IsHcclGroupMarked(out_node, group_name)) {
continue;
}
if (node->GetOutDataNodesSize() == 1) {
if (!AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_name)) {
GELOGE(INTERNAL_ERROR, "Set attr ATTR_NAME_HCCL_FUSED_GROUP failed.");
return INTERNAL_ERROR;
}
return SUCCESS;
}
if (CopySwitchNode(node, out_node, item.second, group_name) != SUCCESS) {
GELOGE(FAILED, "copy switch node for cur_switch=%s, out_node=%s failed",
node->GetName().c_str(), out_node->GetName().c_str());
return FAILED;
}
AddRePassNodesWithInOut(const_cast<NodePtr &>(out_node));
graph_change = true;
}
if (graph_change) {
AddRePassNode(node);
}

return SUCCESS;
}

bool HcclGroupSwitchCopyPass::IsHcclGroupMarked(const NodePtr &branch_head_node, std::string &group_name) {
group_name.clear();
std::stack<NodePtr> nodes;
nodes.push(branch_head_node);
std::set<NodePtr> visited_nodes;
while (!nodes.empty()) {
const NodePtr &cur_node = nodes.top();
nodes.pop();
if (visited_nodes.count(cur_node) > 0) {
continue;
}

std::string type;
GE_CHK_STATUS_RET(GetOriginalType(cur_node, type), "Get node type failed.")
if (kEndTypes.count(type) > 0) {
continue;
}
(void)AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_name);
if (!group_name.empty()) {
return true;
}
for (const auto &out_node : cur_node->GetOutAllNodes()) {
nodes.push(out_node);
}
visited_nodes.insert(cur_node);
}
return false;
}

Status HcclGroupSwitchCopyPass::CopySwitchNode(const NodePtr &switch_node,
const NodePtr &branch_head_node,
const InDataAnchorPtr &peer_in_anchor,
const std::string &group_name) {
const auto &copy_op_desc = AttrUtils::CopyOpDesc(switch_node->GetOpDesc());
GE_CHECK_NOTNULL(copy_op_desc);
copy_op_desc->SetName("CopySwitch_" + std::to_string(copy_num_++));
if (!AttrUtils::SetStr(copy_op_desc, ATTR_NAME_HCCL_FUSED_GROUP, group_name)) {
GELOGE(INTERNAL_ERROR, "Set attr ATTR_NAME_HCCL_FUSED_GROUP failed.");
return INTERNAL_ERROR;
}
const auto &copy_node = switch_node->GetOwnerComputeGraph()->AddNode(copy_op_desc);
GE_CHECK_NOTNULL(copy_node);

const auto &out_data_anchor = peer_in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(out_data_anchor);
if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) ||
(GraphUtils::AddEdge(copy_node->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor) != GRAPH_SUCCESS)) {
GELOGE(FAILED, "Replace data edge %s:%d->%s:%d by %s:%d->%s:%d failed.",
switch_node->GetName().c_str(), out_data_anchor->GetIdx(),
branch_head_node->GetName().c_str(), peer_in_anchor->GetIdx(),
copy_node->GetName().c_str(), out_data_anchor->GetIdx(),
branch_head_node->GetName().c_str(), peer_in_anchor->GetIdx());
return FAILED;
}
for (const auto &in_data_anchor : switch_node->GetAllInDataAnchors()) {
const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) { continue; }
if (GraphUtils::AddEdge(peer_out_anchor, copy_node->GetInDataAnchor(in_data_anchor->GetIdx())) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Add data edge %s:%d->%s:%d failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
peer_out_anchor->GetIdx(), copy_node->GetName().c_str(), in_data_anchor->GetIdx());
return FAILED;
}
}
for (const auto &in_node : switch_node->GetInControlNodes()) {
if (GraphUtils::AddEdge(in_node->GetOutControlAnchor(), copy_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Add ctrl edge %s->%s failed.", in_node->GetName().c_str(), copy_node->GetName().c_str());
return FAILED;
}
}
for (const auto &out_node : switch_node->GetOutControlNodes()) {
if (GraphUtils::AddEdge(copy_node->GetOutControlAnchor(), out_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Add ctrl edge %s->%s failed.", copy_node->GetName().c_str(), out_node->GetName().c_str());
return FAILED;
}
}

return SUCCESS;
}
} // namespace ge

+ 36
- 0
ge/graph/passes/hccl_group_switch_copy_pass.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_HCCL_GROUP_SWITCH_COPY_PASS_H_
#define GE_GRAPH_PASSES_HCCL_GROUP_SWITCH_COPY_PASS_H_

#include "graph/passes/base_pass.h"
namespace ge {
class HcclGroupSwitchCopyPass : public BaseNodePass {
public:
Status Run(NodePtr &node) override;
private:
static bool IsHcclGroupMarked(const NodePtr &branch_head_node, std::string &group_name);
static Status CopySwitchNode(const NodePtr &switch_node,
const NodePtr &branch_head_node,
const InDataAnchorPtr &peer_in_anchor,
const std::string &group_name);

static uint32_t copy_num_;
};
} // namespace ge

#endif // GE_GRAPH_PASSES_HCCL_GROUP_SWITCH_COPY_PASS_H_

+ 3
- 0
ge/graph/preprocess/graph_preprocess.cc View File

@@ -60,6 +60,7 @@
#include "graph/passes/get_original_format_pass.h"
#include "graph/passes/guarantee_const_pass.h"
#include "graph/passes/hccl_group_pass.h"
#include "graph/passes/hccl_group_switch_copy_pass.h"
#include "graph/passes/hccl_memcpy_pass.h"
#include "graph/passes/identity_pass.h"
#include "graph/passes/infershape_pass.h"
@@ -1453,8 +1454,10 @@ Status GraphPrepare::SwitchOpOptimize(ComputeGraphPtr &compute_graph) {
GEPass ge_passes(compute_graph);
NamesToPass hccl_group;
HcclGroupPass hccl_group_pass;
HcclGroupSwitchCopyPass hccl_group_switch_copy_pass;
GELOGD("Add hccl group pass success");
hccl_group.emplace_back("HcclGroupPass", &hccl_group_pass);
hccl_group.emplace_back("HcclGroupSwitchCopyPass", &hccl_group_switch_copy_pass);
auto ret = ge_passes.Run(hccl_group);
if (ret != SUCCESS) {
GELOGE(ret, "Run HcclGroupPass pass for preprocess failed, ret:%u.", ret);


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -243,6 +243,7 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/link_gen_mask_nodes_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/replace_with_empty_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/hccl_group_switch_copy_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"


Loading…
Cancel
Save