|
|
|
@@ -71,28 +71,32 @@ bool FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &k |
|
|
|
std::vector<CNodePtr>::iterator mock_send_node_iter = |
|
|
|
FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch); |
|
|
|
if (mock_send_node_iter == iter + 1) { |
|
|
|
MS_LOG(WARNING) << "Can't find send node place before AllReduce node."; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) { |
|
|
|
MS_LOG(INFO) << "Can't find send node place before AllReduce node."; |
|
|
|
} else if (AnfAlgo::GetCNodeName(*mock_send_node_iter) != kAllReduceOpName) { |
|
|
|
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter, |
|
|
|
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)}; |
|
|
|
send_recv_pairs->push_back(pair1); |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "mock_send_node is AllReduce, no need to add stream switch node."; |
|
|
|
} |
|
|
|
// Find node which uses AllReduce as input[0]. |
|
|
|
std::vector<CNodePtr>::iterator mock_recv_node_iter = |
|
|
|
FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch); |
|
|
|
if (mock_recv_node_iter == iter_end) { |
|
|
|
MS_LOG(WARNING) << "Can't find recv node place after AllReduce node."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) { |
|
|
|
MS_LOG(INFO) << "Can't find recv node place after AllReduce node."; |
|
|
|
} else if (AnfAlgo::GetCNodeName(*mock_recv_node_iter) != kAllReduceOpName) { |
|
|
|
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1), |
|
|
|
IntToSize(mock_recv_node_iter - iter_begin)}; |
|
|
|
send_recv_pairs->push_back(pair2); |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "mock_recv_node is AllReduce, no need to add stream switch node."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (send_recv_pairs->empty()) { |
|
|
|
MS_LOG(INFO) << "No stream switch node is found."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
|