| @@ -985,61 +985,56 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt | |||||
| } | } | ||||
| void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) { | void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) { | ||||
| AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); | |||||
| auto cnode_ptr_list = graph_ptr->execution_order(); | |||||
| // key:group id, key: stream id, value:hcom index | |||||
| std::map<std::string, std::vector<std::pair<uint32_t, vector<size_t>>>> group_hcom_index; | |||||
| for (size_t i = 0; i < cnode_ptr_list.size(); i++) { | |||||
| auto cur_cnode = cnode_ptr_list[i]; | |||||
| if (!IsHcom(cur_cnode)) { | |||||
| continue; | |||||
| } | |||||
| uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) { | |||||
| MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; | |||||
| } | |||||
| auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup); | |||||
| MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name | |||||
| << "; stream id:" << cur_stream_id; | |||||
| auto iter = group_hcom_index.find(group_name); | |||||
| if (iter == group_hcom_index.end()) { | |||||
| std::vector<std::pair<uint32_t, vector<size_t>>> hcom_index; | |||||
| hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); | |||||
| group_hcom_index[group_name] = hcom_index; | |||||
| } else { | |||||
| auto &hcom_index = iter->second; | |||||
| bool exit = false; | |||||
| for (auto &item : hcom_index) { | |||||
| if (item.first == cur_stream_id) { | |||||
| item.second.emplace_back(i); | |||||
| exit = true; | |||||
| break; | |||||
| } | |||||
| if (group_hcom_graph_map_.empty()) { | |||||
| return; | |||||
| } | |||||
| std::vector<string> groups; | |||||
| for (const auto &item : group_hcom_graph_map_) { | |||||
| groups.emplace_back(item.first); | |||||
| } | |||||
| for (const auto &group : groups) { | |||||
| auto cnode_ptr_list = graph_ptr->execution_order(); | |||||
| std::vector<std::pair<uint32_t, vector<size_t>>> stream_indexs; | |||||
| for (size_t i = 0; i < cnode_ptr_list.size(); i++) { | |||||
| auto cur_cnode = cnode_ptr_list[i]; | |||||
| if (!IsHcom(cur_cnode)) { | |||||
| continue; | |||||
| } | } | ||||
| if (!exit) { | |||||
| hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); | |||||
| uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); | |||||
| if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) { | |||||
| MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; | |||||
| } | |||||
| auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup); | |||||
| MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name | |||||
| << "; stream id:" << cur_stream_id; | |||||
| if (group_name != group) { | |||||
| continue; | |||||
| } | } | ||||
| } | |||||
| } | |||||
| for (const auto &hcom_index : group_hcom_index) { | |||||
| MS_LOG(DEBUG) << "Group:" << hcom_index.first; | |||||
| for (const auto &item : hcom_index.second) { | |||||
| MS_LOG(DEBUG) << "stream id:" << item.first; | |||||
| for (const auto &index : item.second) { | |||||
| MS_LOG(DEBUG) << "hcom index:" << index; | |||||
| if (stream_indexs.empty()) { | |||||
| stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); | |||||
| } else { | |||||
| bool exit = false; | |||||
| for (auto &item : stream_indexs) { | |||||
| if (item.first == cur_stream_id) { | |||||
| item.second.emplace_back(i); | |||||
| exit = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!exit) { | |||||
| stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| for (const auto &hcom_index : group_hcom_index) { | |||||
| if (hcom_index.second.size() < 2) { | |||||
| MS_LOG(INFO) << "Group:" << hcom_index.first | |||||
| if (stream_indexs.size() < 2) { | |||||
| MS_LOG(INFO) << "Group:" << group | |||||
| << "; different stream hcom size is less than 2, no need insert event between them"; | << "; different stream hcom size is less than 2, no need insert event between them"; | ||||
| continue; | continue; | ||||
| } | } | ||||
| InsertEventBetweenHcom(graph_ptr, hcom_index.second); | |||||
| MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); | |||||
| InsertEventBetweenHcom(graph_ptr, stream_indexs); | |||||
| } | } | ||||
| } | } | ||||