From 759a7012bb78615e0263e7d6ca46b9d38e4fac8a Mon Sep 17 00:00:00 2001 From: hwjiaorui Date: Sat, 17 Apr 2021 11:11:30 +0800 Subject: [PATCH] fix stream event --- .../device/ascend/ascend_stream_assign.cc | 75 ++++++++++--------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index 39b6c7b2e6..f8c8c4295e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -1481,50 +1481,51 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull if (group_hcom_graph_map_.empty()) { return; } - std::vector groups; - for (const auto &item : group_hcom_graph_map_) { - groups.emplace_back(item.first); - } - for (const auto &group : groups) { - auto cnode_ptr_list = graph_ptr->execution_order(); - std::vector>> stream_indices; - for (size_t i = 0; i < cnode_ptr_list.size(); i++) { - auto cur_cnode = cnode_ptr_list[i]; - if (!IsHcom(cur_cnode)) { - continue; - } - - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); - auto group_name = GetHcomGroup(cur_cnode); - MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name - << "; stream id:" << cur_stream_id; - if (group_name != group) { - continue; - } + for (const auto &group_item : group_hcom_graph_map_) { + auto group = group_item.first; + for (const auto &graph_item : group_item.second) { + auto graph_id = graph_item.first; + auto cnode_ptr_list = graph_ptr->execution_order(); + std::vector>> stream_indices; + for (size_t i = 0; i < cnode_ptr_list.size(); i++) { + auto cur_cnode = cnode_ptr_list[i]; + if (!IsHcom(cur_cnode)) { + continue; + } - if (stream_indices.empty()) { - stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); - } else { - bool exit = false; - for (auto &item : stream_indices) { - if (item.first == cur_stream_id) { - item.second.emplace_back(i); - exit = true; - break; - } + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + auto group_name = GetHcomGroup(cur_cnode); + auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode.get()); + MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name + << "; stream id:" << cur_stream_id; + if (group_name != group || cur_graph_id != graph_id) { + continue; } - if (!exit) { + + if (stream_indices.empty()) { stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); + } else { + bool exit = false; + for (auto &item : stream_indices) { + if (item.first == cur_stream_id) { + item.second.emplace_back(i); + exit = true; + break; + } + } + if (!exit) { + stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); + } } } - } - if (stream_indices.size() < 2) { - MS_LOG(INFO) << "Group:" << group - << "; different stream hcom size is less than 2, no need insert event between them"; - continue; + if (stream_indices.size() < 2) { + MS_LOG(INFO) << "Group:" << group + << "; different stream hcom size is less than 2, no need insert event between them"; + continue; + } + InsertEventBetweenHcom(graph_ptr, stream_indices); } - InsertEventBetweenHcom(graph_ptr, stream_indices); } }