|
|
@@ -291,6 +291,74 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr<session::KernelG |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) { |
|
|
|
|
|
MS_LOG(INFO) << "start"; |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_ptr); |
|
|
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
|
|
vector<uint32_t> fusion_hcom_index; |
|
|
|
|
|
vector<CNodePtr> orders; |
|
|
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); i++) { |
|
|
|
|
|
auto cur_cnode = cnode_ptr_list[i]; |
|
|
|
|
|
if (IsHcom(cur_cnode)) { |
|
|
|
|
|
fusion_hcom_index.emplace_back(i); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (fusion_hcom_index.size() < 2) { |
|
|
|
|
|
MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them"; |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
uint32_t first_index = fusion_hcom_index[0]; |
|
|
|
|
|
uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1]; |
|
|
|
|
|
|
|
|
|
|
|
uint32_t cur_event_id = total_event_num_; |
|
|
|
|
|
uint32_t pre_hcom_stream_id = UINT32_MAX; |
|
|
|
|
|
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders)); |
|
|
|
|
|
for (size_t i = first_index; i <= last_index; i++) { |
|
|
|
|
|
auto cur_cnode = cnode_ptr_list[i]; |
|
|
|
|
|
auto it = std::find(fusion_hcom_index.begin(), fusion_hcom_index.end(), i); |
|
|
|
|
|
if (it == fusion_hcom_index.end()) { |
|
|
|
|
|
orders.emplace_back(cur_cnode); |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode); |
|
|
|
|
|
if (cur_hcom_stream_id == pre_hcom_stream_id) { |
|
|
|
|
|
orders.emplace_back(cur_cnode); |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (i == first_index) { |
|
|
|
|
|
// first fusion hcom |
|
|
|
|
|
orders.emplace_back(cur_cnode); |
|
|
|
|
|
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); |
|
|
|
|
|
orders.emplace_back(send); |
|
|
|
|
|
} else if (i == last_index) { |
|
|
|
|
|
// last fusion hcom |
|
|
|
|
|
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); |
|
|
|
|
|
orders.emplace_back(recv); |
|
|
|
|
|
orders.emplace_back(cur_cnode); |
|
|
|
|
|
cur_event_id++; |
|
|
|
|
|
} else { |
|
|
|
|
|
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); |
|
|
|
|
|
orders.emplace_back(recv); |
|
|
|
|
|
cur_event_id++; |
|
|
|
|
|
orders.emplace_back(cur_cnode); |
|
|
|
|
|
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); |
|
|
|
|
|
orders.emplace_back(send); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
pre_hcom_stream_id = cur_hcom_stream_id; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); |
|
|
|
|
|
graph_ptr->set_execution_order(orders); |
|
|
|
|
|
total_event_num_ = cur_event_id; |
|
|
|
|
|
MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]"; |
|
|
|
|
|
MS_LOG(INFO) << "end"; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) { |
|
|
void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) { |
|
|
MS_LOG(INFO) << "start"; |
|
|
MS_LOG(INFO) << "start"; |
|
|
MS_EXCEPTION_IF_NULL(graph_ptr); |
|
|
MS_EXCEPTION_IF_NULL(graph_ptr); |
|
|
@@ -324,6 +392,9 @@ void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspor |
|
|
graph_ptr->set_execution_order(cnodes); |
|
|
graph_ptr->set_execution_order(cnodes); |
|
|
total_event_num_ = cur_event_id; |
|
|
total_event_num_ = cur_event_id; |
|
|
MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]"; |
|
|
MS_LOG(INFO) << "after insert send/recv for hcom parallel, total event nums[" << total_event_num_ << "]"; |
|
|
|
|
|
|
|
|
|
|
|
// Insert Send/Recv between Hcom(such as:AllReduce1 Send1 Common Recv1 AllReduce2) |
|
|
|
|
|
InsertSendRecvForDiffHcom(graph_ptr); |
|
|
MS_LOG(INFO) << "end"; |
|
|
MS_LOG(INFO) << "end"; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|