Browse Source

!21385 Fix the case where the same node's output is used by two communication ops

Merge pull request !21385 from zhoufeng/xiu-ba-ge
r1.4
i-robot Gitee 4 years ago
parent
commit
c1d65a3f76
3 changed files with 19 additions and 15 deletions
  1. +14
    -9
      mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.cc
  2. +1
    -1
      mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.h
  3. +4
    -5
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc

+ 14
- 9
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.cc View File

@@ -41,15 +41,18 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) {
}

// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users, const CNodePtr &known_user,
size_t known_index) {
if (node_users.size() == 1) {
MS_LOG(INFO) << "This node only used once, no need to insert tensormove node.";
return false;
}
for (const auto &node_pair : node_users) {
auto node = node_pair.first;
if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope();
auto &node = node_pair.first;
size_t idx = IntToSize(node_pair.second);
if (AnfAlgo::IsRealKernel(node) && !(known_user == node && known_index == idx)) {
MS_LOG(INFO) << "User " << node->DebugString() << " idx " << idx << " is real kernel and diff with known "
<< known_user->DebugString() << " idx " << known_index;
return true;
}
}
@@ -58,11 +61,13 @@ bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
}
} // namespace

bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph, const AnfNodePtr &input,
const CNodePtr &cur_node) const {
bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph, const CNodePtr &cur_node,
size_t input_idx) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(cur_node);
auto input = cur_node->input(input_idx);
MS_EXCEPTION_IF_NULL(input);

if (IsPrimitiveCNode(cur_node, prim::kPrimReceive)) {
return false;
}
@@ -93,7 +98,7 @@ bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph,
MS_LOG(EXCEPTION) << "node has no output in manager"
<< " trace: " << trace::DumpSourceLines(input);
}
if (IsNodeOutPutUsedByOtherRealKernel(iter->second)) {
if (IsNodeOutPutUsedByOtherRealKernel(iter->second, cur_node, input_idx)) {
return true;
}
}
@@ -107,7 +112,7 @@ void InsertTensorMoveForHcclOp::InsertTensorMove(const FuncGraphPtr &graph, cons
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
for (size_t i = 1; i < hccl_node->size(); ++i) {
auto input = hccl_node->input(i);
if (NeedInsertTensorMove(graph, input, hccl_node)) {
if (NeedInsertTensorMove(graph, hccl_node, i)) {
auto tensor_move = CreateTensorMoveOp(graph, input);
if (tensor_move == nullptr) {
MS_LOG(EXCEPTION) << "Create tensor_move op failed.";


+ 1
- 1
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_tensor_move_for_hccl_op.h View File

@@ -32,7 +32,7 @@ class InsertTensorMoveForHcclOp : public PatternProcessPass {

private:
void InsertTensorMove(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
bool NeedInsertTensorMove(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const;
bool NeedInsertTensorMove(const FuncGraphPtr &graph, const CNodePtr &cur_node, size_t input_idx) const;
KernelQueryPtr kernel_query_;
};
} // namespace opt


+ 4
- 5
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

// Returns true when `node` is an HCCL communication CNode (AllReduce, AllGather,
// Broadcast, ReduceScatter, Send, Receive, AllToAllV). Non-CNode inputs
// (parameters, value nodes) are never communication ops.
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  // Built once on first call; set membership replaces the previous if-chain and
  // also covers AllToAllV, which the chain missed.
  static const std::set<std::string> kCommunicationOpNames = {kAllReduceOpName,     kAllGatherOpName, kBroadcastOpName,
                                                              kReduceScatterOpName, kHcomSendOpName,  kReceiveOpName,
                                                              kAllToAllVOpName};
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto kernel_name = AnfAlgo::GetCNodeName(node);
  return (kCommunicationOpNames.find(kernel_name) != kCommunicationOpNames.end());
}

bool AnfRuntimeAlgorithm::IsFusedCommunicationOp(const AnfNodePtr &node) {


Loading…
Cancel
Save