Browse Source

fix_consecutive_allreduce_bug

tags/v0.7.0-beta
yuchaojie 5 years ago
parent
commit
61bf4b18a2
2 changed files with 19 additions and 9 deletions
  1. +17
    -8
      mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc
  2. +2
    -1
      mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h

+ 17
- 8
mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc View File

@@ -74,9 +74,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
continue;
}
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
send_recv_pairs->push_back(pair1);
if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) {
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
send_recv_pairs->push_back(pair1);
}
// Find node which uses AllReduce as input[0].
std::vector<CNodePtr>::iterator mock_recv_node_iter =
FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
@@ -84,9 +86,11 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k
MS_LOG(WARNING) << "Can't find recv node place after AllReduce node.";
return false;
}
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
IntToSize(mock_recv_node_iter - iter_begin)};
send_recv_pairs->push_back(pair2);
if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) {
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
IntToSize(mock_recv_node_iter - iter_begin)};
send_recv_pairs->push_back(pair2);
}
}
}
return true;
@@ -110,17 +114,22 @@ std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator
std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
StreamSwitchType stream_switch_type) {
MS_EXCEPTION_IF_NULL(mock_send_node);
auto ret = end;
for (auto iter = begin; iter != end; iter++) {
auto node = *iter;
if (stream_switch_type == kAllReduceStreamSwitch) {
for (auto input : node->inputs()) {
if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) {
return iter;
if (AnfAlgo::GetCNodeName(node) != kAllReduceOpName) {
return iter;
} else if (ret == end) {
ret = iter;
}
}
}
}
}
return end;
return ret;
}

void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,


+ 2
- 1
mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h View File

@@ -41,7 +41,8 @@ struct StreamSwitchNode {
if (offset < n.offset) {
return true;
} else if (offset == n.offset) {
return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false;
return (AnfAlgo::GetCNodeName(cnode) == kRecvOpName && AnfAlgo::GetCNodeName(n.cnode) == kSendOpName) ? false
: true;
} else {
return false;
}


Loading…
Cancel
Save