From 7e6acd45cf8c32d574e643afed7be778c39d76f9 Mon Sep 17 00:00:00 2001 From: gukecai Date: Mon, 7 Dec 2020 14:39:33 +0800 Subject: [PATCH] fix --- .../device/ascend/ascend_stream_assign.cc | 87 +++++++++---------- 1 file changed, 41 insertions(+), 46 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index 9959a65966..3b278c2d31 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -985,61 +985,56 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - // key:group id, key: stream id, value:hcom index - std::map>>> group_hcom_index; - for (size_t i = 0; i < cnode_ptr_list.size(); i++) { - auto cur_cnode = cnode_ptr_list[i]; - if (!IsHcom(cur_cnode)) { - continue; - } - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); - if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) { - MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; - } - auto group_name = AnfAlgo::GetNodeAttr(cur_cnode, kAttrGroup); - MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name - << "; stream id:" << cur_stream_id; - auto iter = group_hcom_index.find(group_name); - if (iter == group_hcom_index.end()) { - std::vector>> hcom_index; - hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); - group_hcom_index[group_name] = hcom_index; - } else { - auto &hcom_index = iter->second; - bool exit = false; - for (auto &item : hcom_index) { - if (item.first == cur_stream_id) { - item.second.emplace_back(i); - exit = true; - break; - } + if (group_hcom_graph_map_.empty()) { + return; + } + std::vector groups; + for (const auto &item : group_hcom_graph_map_) { + groups.emplace_back(item.first); + } + for (const auto &group : groups) { + auto cnode_ptr_list = graph_ptr->execution_order(); + std::vector>> stream_indexs; + for (size_t i = 0; i < cnode_ptr_list.size(); i++) { + auto cur_cnode = cnode_ptr_list[i]; + if (!IsHcom(cur_cnode)) { + continue; } - if (!exit) { - hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); + + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) { + MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; + } + auto group_name = AnfAlgo::GetNodeAttr(cur_cnode, kAttrGroup); + MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name + << "; stream id:" << cur_stream_id; + if (group_name != group) { + continue; } - } - } - for (const auto &hcom_index : group_hcom_index) { - MS_LOG(DEBUG) << "Group:" << hcom_index.first; - for (const auto &item : hcom_index.second) { - MS_LOG(DEBUG) << "stream id:" << item.first; - for (const auto &index : item.second) { - MS_LOG(DEBUG) << "hcom index:" << index; + if (stream_indexs.empty()) { + stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); + } else { + bool exit = false; + for (auto &item : stream_indexs) { + if (item.first == cur_stream_id) { + item.second.emplace_back(i); + exit = true; + break; + } + } + if (!exit) { + stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector{i})); + } } } - } - for (const auto &hcom_index : group_hcom_index) { - if (hcom_index.second.size() < 2) { - MS_LOG(INFO) << "Group:" << hcom_index.first + if (stream_indexs.size() < 2) { + MS_LOG(INFO) << "Group:" << group << "; different stream hcom size is less than 2, no need insert event between them"; continue; } - InsertEventBetweenHcom(graph_ptr, hcom_index.second); - MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); + InsertEventBetweenHcom(graph_ptr, stream_indexs); } }