|
|
|
@@ -48,6 +48,12 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) |
|
|
|
CheckResourceAssign(graph_ptr); |
|
|
|
MS_LOG(INFO) << "After finish stream assign"; |
|
|
|
|
|
|
|
FindStreamRelations(graph_ptr); |
|
|
|
PrintStreamRelations(); |
|
|
|
GetStreamRelations(); |
|
|
|
PrintStreamGroups(); |
|
|
|
FindEventRelations(graph_ptr); |
|
|
|
|
|
|
|
// Get info for D Model |
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); |
|
|
|
generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num()); |
|
|
|
@@ -501,6 +507,8 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt |
|
|
|
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); |
|
|
|
cnodes.emplace_back(recv); |
|
|
|
cnodes.emplace_back(cur_cnode_ptr); |
|
|
|
} else { |
|
|
|
cnodes.emplace_back(cur_cnode_ptr); |
|
|
|
} |
|
|
|
pre_stream_id = cur_stream_id; |
|
|
|
} |
|
|
|
@@ -910,7 +918,351 @@ void AscendStreamAssign::Reset() { |
|
|
|
common_stream_map_.clear(); |
|
|
|
processed_streams_.clear(); |
|
|
|
need_first_active_streams_.clear(); |
|
|
|
stream_groups_.clear(); |
|
|
|
stream_relations_.clear(); |
|
|
|
event_map_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
// section 10 |
|
|
|
bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) { |
|
|
|
auto group_size = group->size(); |
|
|
|
if (group_size == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
for (const auto &item : stream_groups_) { |
|
|
|
if (item.size() < group->size()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
bool flag = true; |
|
|
|
for (size_t i = 0; i < group_size; i++) { |
|
|
|
if (item[i] != group->at(i)) { |
|
|
|
flag = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (flag) { |
|
|
|
return true; |
|
|
|
} else { |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) { |
|
|
|
auto it = stream_relations_.find(start); |
|
|
|
if (it == stream_relations_.end()) { |
|
|
|
if (!IsVecExist(group)) { |
|
|
|
stream_groups_.emplace_back(*group); |
|
|
|
} else { |
|
|
|
MS_LOG(WARNING) << "DFS should not print this log"; |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
vector<uint32_t> active_streams = stream_relations_[start]; |
|
|
|
|
|
|
|
for (const auto &item : active_streams) { |
|
|
|
group->emplace_back(item); |
|
|
|
DFS(item, group); |
|
|
|
group->pop_back(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::GetStreamRelations() { |
|
|
|
for (const auto &start : need_first_active_streams_) { |
|
|
|
vector<uint32_t> group{start}; |
|
|
|
DFS(start, &group); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); |
|
|
|
auto stream_num = resource_manager.get_cur_stream_num(); |
|
|
|
if (stream_num <= 1) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto exe_orders = graph_ptr->execution_order(); |
|
|
|
for (size_t i = 0; i < exe_orders.size(); i++) { |
|
|
|
auto cur_cnode = exe_orders[i]; |
|
|
|
auto name = AnfAlgo::GetCNodeName(cur_cnode); |
|
|
|
if (name != kStreamSwitchOpName && name != kStreamActiveOpName) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// support:streamswitch is begin of the stream |
|
|
|
if (name == kStreamSwitchOpName) { |
|
|
|
GetStreamSwitchStreamRelation(cur_cnode); |
|
|
|
} |
|
|
|
|
|
|
|
if (name == kStreamActiveOpName) { |
|
|
|
GetStreamActiveStreamRelation(graph_ptr, i); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::GetStreamSwitchStreamRelation(const CNodePtr &node_ptr) { |
|
|
|
MS_EXCEPTION_IF_NULL(node_ptr); |
|
|
|
auto cur_stream_id = AnfAlgo::GetStreamId(node_ptr); |
|
|
|
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(node_ptr, kAttrTrueBranchStream); |
|
|
|
if (true_stream_id <= cur_stream_id) { |
|
|
|
MS_LOG(ERROR) << "StreamSwitch self stream id " << cur_stream_id |
|
|
|
<< " is greater than true branch stream id:" << true_stream_id; |
|
|
|
} |
|
|
|
auto it = stream_relations_.find(cur_stream_id); |
|
|
|
if (it == stream_relations_.end()) { |
|
|
|
stream_relations_[cur_stream_id] = {true_stream_id}; |
|
|
|
} else { |
|
|
|
auto iter = |
|
|
|
std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), true_stream_id); |
|
|
|
if (iter == stream_relations_[cur_stream_id].end()) { |
|
|
|
stream_relations_[cur_stream_id].emplace_back(true_stream_id); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) { |
|
|
|
StreamActiveKind kind = GetStreamActiveKind(graph_ptr, index); |
|
|
|
if (kind == kInvalid) { |
|
|
|
MS_LOG(INFO) << "Invalid streamActive kind"; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto orders = graph_ptr->execution_order(); |
|
|
|
auto cur_cnode = orders[index]; |
|
|
|
auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); |
|
|
|
auto active_list = AnfAlgo::GetNodeAttr<vector<uint32_t>>(cur_cnode, kAttrActiveStreamList); |
|
|
|
if (kind == kHead) { |
|
|
|
uint32_t active_current_node = GetStreamByActivedStream(cur_stream_id); |
|
|
|
if (active_current_node == kInvalidStreamId) { |
|
|
|
MS_LOG(EXCEPTION) << "No stream to active streamactive stream"; |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &item : active_list) { |
|
|
|
if (item <= active_current_node) { |
|
|
|
MS_LOG(WARNING) << "Actived stream is less than activing stream"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto it = |
|
|
|
std::find(stream_relations_[active_current_node].begin(), stream_relations_[active_current_node].end(), item); |
|
|
|
if (it == stream_relations_[active_current_node].end()) { |
|
|
|
stream_relations_[active_current_node].emplace_back(item); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (kind == kMiddle) { |
|
|
|
for (const auto &stream : active_list) { |
|
|
|
if (stream <= cur_stream_id) { |
|
|
|
MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal"; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "MIDDLE StreamActive active stream is greater than self stream, should not be exit now"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (kind == kTail) { |
|
|
|
auto it = stream_relations_.find(cur_stream_id); |
|
|
|
if (it == stream_relations_.end()) { |
|
|
|
stream_relations_[cur_stream_id] = active_list; |
|
|
|
} else { |
|
|
|
for (const auto &stream : active_list) { |
|
|
|
if (stream <= cur_stream_id) { |
|
|
|
MS_LOG(WARNING) << "Actived stream is less than activing stream"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream); |
|
|
|
if (iter == stream_relations_[cur_stream_id].end()) { |
|
|
|
stream_relations_[cur_stream_id].emplace_back(stream); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) { |
|
|
|
auto exe_orders = graph_ptr->execution_order(); |
|
|
|
if (index >= exe_orders.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid op index:" << index; |
|
|
|
} |
|
|
|
|
|
|
|
auto cur_cnode = exe_orders[index]; |
|
|
|
auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); |
|
|
|
if (AnfAlgo::GetCNodeName(cur_cnode) != kStreamActiveOpName) { |
|
|
|
MS_LOG(EXCEPTION) << "Current node name is not StreamActive"; |
|
|
|
} |
|
|
|
|
|
|
|
if (index == 0) { |
|
|
|
return kInvalid; |
|
|
|
} |
|
|
|
|
|
|
|
if (index == exe_orders.size() - 1) { |
|
|
|
return kInvalid; |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t pre_stream_id = UINT32_MAX; |
|
|
|
uint32_t next_stream_id = UINT32_MAX; |
|
|
|
int32_t start = SizeToInt(index); |
|
|
|
for (int32_t i = start; i >= 0; i--) { |
|
|
|
auto cnode = exe_orders[IntToSize(i)]; |
|
|
|
auto name = AnfAlgo::GetCNodeName(cnode); |
|
|
|
if (name == kSendOpName || name == kRecvOpName) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
pre_stream_id = AnfAlgo::GetStreamId(cnode); |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t i = index + 1; i < exe_orders.size(); i++) { |
|
|
|
auto cnode = exe_orders[i]; |
|
|
|
auto name = AnfAlgo::GetCNodeName(cnode); |
|
|
|
if (name == kSendOpName || name == kRecvOpName) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
next_stream_id = AnfAlgo::GetStreamId(cnode); |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
// pre_stream_id = UINT32_MAX:means no node active current StreamActive |
|
|
|
// next_stream_id = UINT32_MAX:means current StreamActive active no node |
|
|
|
if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) { |
|
|
|
return kInvalid; |
|
|
|
} |
|
|
|
|
|
|
|
if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) { |
|
|
|
return kMiddle; |
|
|
|
} |
|
|
|
|
|
|
|
if (cur_stream_id == pre_stream_id) { |
|
|
|
return kTail; |
|
|
|
} |
|
|
|
|
|
|
|
if (cur_stream_id == next_stream_id) { |
|
|
|
return kHead; |
|
|
|
} |
|
|
|
|
|
|
|
return kInvalid; |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) { |
|
|
|
if (stream_relations_.empty()) { |
|
|
|
return kInvalidStreamId; |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &item : stream_relations_) { |
|
|
|
auto it = std::find(item.second.begin(), item.second.end(), actived_stream_id); |
|
|
|
if (it != item.second.end()) { |
|
|
|
return item.first; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return kInvalidStreamId; |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::PrintStreamRelations() { |
|
|
|
MS_LOG(INFO) << "Stream relations size:" << stream_relations_.size(); |
|
|
|
for (const auto &item : stream_relations_) { |
|
|
|
MS_LOG(INFO) << "Stream:" << item.first; |
|
|
|
for (const auto &stream : item.second) { |
|
|
|
MS_LOG(INFO) << "--actived stream id:" << stream; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::PrintStreamGroups() { |
|
|
|
MS_LOG(INFO) << "Stream group size:" << stream_groups_.size(); |
|
|
|
for (const auto &item : stream_groups_) { |
|
|
|
MS_LOG(INFO) << "Group:"; |
|
|
|
for (const auto &stream : item) { |
|
|
|
MS_LOG(INFO) << "Stream id:" << stream; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// section 11 |
|
|
|
bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const { |
|
|
|
size_t send_group = 0; |
|
|
|
size_t recv_group = 0; |
|
|
|
bool send_flag = true; |
|
|
|
bool recv_flag = true; |
|
|
|
for (size_t i = 0; i < stream_groups_.size(); i++) { |
|
|
|
auto group = stream_groups_[i]; |
|
|
|
if (send_flag) { |
|
|
|
auto it = std::find(group.begin(), group.end(), send_stream_id); |
|
|
|
if (it != group.end()) { |
|
|
|
send_group = i; |
|
|
|
send_flag = false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (recv_flag) { |
|
|
|
auto it = std::find(group.begin(), group.end(), recv_stream_id); |
|
|
|
if (it != group.end()) { |
|
|
|
recv_group = i; |
|
|
|
recv_flag = false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (!(send_flag || recv_flag)) { |
|
|
|
return (send_group != recv_group); |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); |
|
|
|
auto event_nums = resource_manager.get_cur_event_num(); |
|
|
|
if (event_nums == 0) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto exe_orders = graph_ptr->execution_order(); |
|
|
|
// find all event info |
|
|
|
for (size_t i = 0; i < exe_orders.size(); i++) { |
|
|
|
auto cur_cnode = exe_orders[i]; |
|
|
|
auto name = AnfAlgo::GetCNodeName(cur_cnode); |
|
|
|
if (name == kSendOpName) { |
|
|
|
event_map_[cur_cnode] = {}; |
|
|
|
} |
|
|
|
|
|
|
|
if (name == kRecvOpName) { |
|
|
|
auto recv_event_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode, kAttrEventId); |
|
|
|
for (auto &item : event_map_) { |
|
|
|
auto send_event_id = AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId); |
|
|
|
if (recv_event_id == send_event_id) { |
|
|
|
item.second = cur_cnode; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// delete useless event info |
|
|
|
auto begin = event_map_.begin(); |
|
|
|
while (begin != event_map_.end()) { |
|
|
|
auto send_stream_id = AnfAlgo::GetStreamId(begin->first); |
|
|
|
auto recv_stream_id = AnfAlgo::GetStreamId(begin->second); |
|
|
|
bool flag = IsSatisfiedEvent(send_stream_id, recv_stream_id); |
|
|
|
if (!flag) { |
|
|
|
begin = event_map_.erase(begin); |
|
|
|
} else { |
|
|
|
begin++; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Satisfied event info"; |
|
|
|
for (const auto &item : event_map_) { |
|
|
|
MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace ascend |
|
|
|
} // namespace device |
|
|
|
} // namespace mindspore |