|
|
|
@@ -104,7 +104,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto res = FindTargetOp(begin, end, cur_independent); |
|
|
|
auto res = FindTargetOp(begin, end, cur_independent, false); |
|
|
|
if (res != end) { |
|
|
|
flag = true; |
|
|
|
exe_orders.emplace_back(cur_independent); |
|
|
|
@@ -247,10 +247,6 @@ void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) { |
|
|
|
} |
|
|
|
group_hcom_graph_map_[diff_group.first] = hcom_graph_map; |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &item : group_hcom_graph_map_) { |
|
|
|
MS_LOG_INFO << "group id:" << item.first << "; hcom stream nums:" << item.second.size(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) { |
|
|
|
@@ -787,7 +783,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt |
|
|
|
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); |
|
|
|
it = cnodes.insert(it + 1, send_cnode_ptr); |
|
|
|
|
|
|
|
auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); |
|
|
|
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), true); |
|
|
|
if (target == cnodes.end()) { |
|
|
|
MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope() |
|
|
|
<< ", can't find target for insert recv op, no insert send/recv"; |
|
|
|
@@ -795,11 +791,6 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsHcom(*target)) { |
|
|
|
it = cnodes.erase(it); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// deal recv op |
|
|
|
uint32_t stream_id = AnfAlgo::GetStreamId(*target); |
|
|
|
CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); |
|
|
|
@@ -834,15 +825,26 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap |
|
|
|
} |
|
|
|
|
|
|
|
// get the input which located in the lastr exe orders |
|
|
|
auto last_input_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr); |
|
|
|
auto it = std::find(cnodes.begin(), cnodes.end(), last_input_cnode); |
|
|
|
if (it == cnodes.end()) { |
|
|
|
MS_LOG(ERROR) << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) |
|
|
|
<< "get last input:" << AnfAlgo::GetCNodeName(last_input_cnode) << "; but last input not in cnodes"; |
|
|
|
} else { |
|
|
|
vector<CNodePtr> inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr); |
|
|
|
if (inputs_cnode.empty()) { |
|
|
|
MS_LOG(WARNING) << "Hcom op:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " can't find inputs nodes"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Current hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) |
|
|
|
<< "; inputs cnode size:" << inputs_cnode.size(); |
|
|
|
|
|
|
|
for (size_t j = 0; j < inputs_cnode.size(); j++) { |
|
|
|
auto &cur_input = inputs_cnode.at(j); |
|
|
|
MS_LOG(INFO) << "The index:" << j << " input, name:" << AnfAlgo::GetCNodeName(cur_input); |
|
|
|
uint32_t cur_event_id = resource_manager.ApplyNewEvent(); |
|
|
|
auto last_stream_id = AnfAlgo::GetStreamId(last_input_cnode); |
|
|
|
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, last_stream_id); |
|
|
|
auto pre_stream_id = AnfAlgo::GetStreamId(cur_input); |
|
|
|
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id); |
|
|
|
auto it = std::find(cnodes.begin(), cnodes.end(), cur_input); |
|
|
|
if (it == cnodes.end()) { |
|
|
|
MS_LOG_EXCEPTION << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) |
|
|
|
<< " can't find input node:" << AnfAlgo::GetCNodeName(cur_input); |
|
|
|
} |
|
|
|
cnodes.insert(it + 1, send); |
|
|
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); |
|
|
|
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); |
|
|
|
@@ -855,26 +857,56 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap |
|
|
|
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, |
|
|
|
const CNodePtr &cur_cnode_ptr) { |
|
|
|
vector<CNodePtr> AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, |
|
|
|
const CNodePtr &cur_cnode_ptr) { |
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup); |
|
|
|
auto input_cnodes = GetInputKernels(cur_cnode_ptr); |
|
|
|
if (input_cnodes.empty()) { |
|
|
|
return nullptr; |
|
|
|
return {}; |
|
|
|
} |
|
|
|
// record max index node for each stream |
|
|
|
std::map<uint32_t, std::pair<CNodePtr, uint32_t>> result; |
|
|
|
for (size_t i = 0; i < input_cnodes.size(); i++) { |
|
|
|
auto &cur_input = input_cnodes.at(i); |
|
|
|
auto stream_id = AnfAlgo::GetStreamId(cur_input); |
|
|
|
auto cur_index = GetIndexByKey(graph_ptr, cur_input.get()); |
|
|
|
if (cur_index == UINT32_MAX) { |
|
|
|
MS_LOG_EXCEPTION << "The input node:" << AnfAlgo::GetCNodeName(cur_input) << " is not found in graph"; |
|
|
|
} |
|
|
|
auto it = result.find(stream_id); |
|
|
|
if (it == result.end()) { |
|
|
|
result[stream_id] = std::make_pair(cur_input, cur_index); |
|
|
|
} else { |
|
|
|
auto max_index = it->second.second; |
|
|
|
if (cur_index > max_index) { |
|
|
|
result[stream_id] = std::make_pair(cur_input, cur_index); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
auto it_pos = cnode_ptr_list.begin(); |
|
|
|
|
|
|
|
for (auto &cnode : input_cnodes) { |
|
|
|
auto it = std::find(it_pos, cnode_ptr_list.end(), cnode); |
|
|
|
if (it != cnode_ptr_list.end()) { |
|
|
|
it_pos = it; |
|
|
|
vector<CNodePtr> final_inputs; |
|
|
|
uint32_t max = 0; |
|
|
|
CNodePtr max_common_cnode = nullptr; |
|
|
|
for (const auto &item : result) { |
|
|
|
if (IsHcom(item.second.first)) { |
|
|
|
auto cur_group = AnfAlgo::GetNodeAttr<std::string>(item.second.first, kAttrGroup); |
|
|
|
if (cur_group == group_name) { |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
final_inputs.emplace_back(item.second.first); |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (item.second.second > max) { |
|
|
|
max_common_cnode = item.second.first; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (it_pos == cnode_ptr_list.begin() && *it_pos != input_cnodes.front()) { |
|
|
|
MS_LOG(ERROR) << "The input of node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "was not found"; |
|
|
|
} |
|
|
|
|
|
|
|
return *it_pos; |
|
|
|
if (max_common_cnode != nullptr) { |
|
|
|
final_inputs.emplace_back(max_common_cnode); |
|
|
|
} |
|
|
|
return final_inputs; |
|
|
|
} |
|
|
|
|
|
|
|
vector<CNodePtr> AscendStreamAssign::GetInputKernels(const CNodePtr &node) { |
|
|
|
@@ -956,9 +988,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> |
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); |
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
// key:group id, key: stream id, value:hcom index |
|
|
|
std::map<std::string, std::map<uint32_t, vector<size_t>>> group_hcom_index; |
|
|
|
std::map<std::string, uint32_t> group_first_hcom_stream; |
|
|
|
std::map<std::string, uint32_t> group_last_hcom_stream; |
|
|
|
std::map<std::string, std::vector<std::pair<uint32_t, vector<size_t>>>> group_hcom_index; |
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); i++) { |
|
|
|
auto cur_cnode = cnode_ptr_list[i]; |
|
|
|
if (!IsHcom(cur_cnode)) { |
|
|
|
@@ -969,67 +999,60 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> |
|
|
|
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr"; |
|
|
|
} |
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup); |
|
|
|
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name |
|
|
|
<< "; stream id:" << cur_stream_id; |
|
|
|
auto iter = group_hcom_index.find(group_name); |
|
|
|
if (iter == group_hcom_index.end()) { |
|
|
|
std::map<uint32_t, vector<size_t>> hcom_index; |
|
|
|
hcom_index[cur_stream_id] = {i}; |
|
|
|
std::vector<std::pair<uint32_t, vector<size_t>>> hcom_index; |
|
|
|
hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); |
|
|
|
group_hcom_index[group_name] = hcom_index; |
|
|
|
} else { |
|
|
|
auto &hcom_index = iter->second; |
|
|
|
auto it = hcom_index.find(cur_stream_id); |
|
|
|
if (it == hcom_index.end()) { |
|
|
|
hcom_index[cur_stream_id] = {i}; |
|
|
|
} else { |
|
|
|
it->second.emplace_back(i); |
|
|
|
bool exit = false; |
|
|
|
for (auto &item : hcom_index) { |
|
|
|
if (item.first == cur_stream_id) { |
|
|
|
item.second.emplace_back(i); |
|
|
|
exit = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!exit) { |
|
|
|
hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i})); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// record first hcom stream id |
|
|
|
auto it = group_first_hcom_stream.find(group_name); |
|
|
|
if (it == group_first_hcom_stream.end()) { |
|
|
|
group_first_hcom_stream[group_name] = cur_stream_id; |
|
|
|
} |
|
|
|
|
|
|
|
// record last hcom stream id |
|
|
|
it = group_last_hcom_stream.find(group_name); |
|
|
|
if (it != group_last_hcom_stream.end()) { |
|
|
|
it->second = cur_stream_id; |
|
|
|
} else { |
|
|
|
group_last_hcom_stream[group_name] = cur_stream_id; |
|
|
|
for (const auto &hcom_index : group_hcom_index) { |
|
|
|
MS_LOG(DEBUG) << "Group:" << hcom_index.first; |
|
|
|
for (const auto &item : hcom_index.second) { |
|
|
|
MS_LOG(DEBUG) << "stream id:" << item.first; |
|
|
|
for (const auto &index : item.second) { |
|
|
|
MS_LOG(DEBUG) << "hcom index:" << index; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &hcom_index : group_hcom_index) { |
|
|
|
if (hcom_index.second.size() < 2) { |
|
|
|
MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto group_name = hcom_index.first; |
|
|
|
auto it = group_first_hcom_stream.find(group_name); |
|
|
|
if (it == group_first_hcom_stream.end()) { |
|
|
|
MS_LOG_EXCEPTION << "Can't find first hcom stream, hcom group id:" << group_name; |
|
|
|
} |
|
|
|
auto first_hcom_stream = it->second; |
|
|
|
|
|
|
|
it = group_last_hcom_stream.find(group_name); |
|
|
|
if (it == group_last_hcom_stream.end()) { |
|
|
|
MS_LOG_EXCEPTION << "Can't find last hcom stream, hcom group id:" << group_name; |
|
|
|
MS_LOG(INFO) << "Group:" << hcom_index.first |
|
|
|
<< "; different stream hcom size is less than 2, no need insert event between them"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto last_hcom_stream = it->second; |
|
|
|
InsertEventBetweenHcom(graph_ptr, hcom_index.second, first_hcom_stream, last_hcom_stream); |
|
|
|
InsertEventBetweenHcom(graph_ptr, hcom_index.second); |
|
|
|
MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, |
|
|
|
const map<uint32_t, vector<size_t>> &hcom_index, |
|
|
|
uint32_t first_hcom_stream, uint32_t last_hcom_stream) { |
|
|
|
const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index) { |
|
|
|
vector<CNodePtr> orders; |
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); |
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order(); |
|
|
|
uint32_t cur_event_id = resource_manager.ApplyNewEvent(); |
|
|
|
size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back(); |
|
|
|
size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front(); |
|
|
|
size_t first_stream_last_index = hcom_index[0].second.back(); |
|
|
|
size_t last_stream_first_index = hcom_index.back().second.front(); |
|
|
|
MS_LOG(INFO) << "First stream last index:" << first_stream_last_index |
|
|
|
<< "; last stream first index:" << last_stream_first_index; |
|
|
|
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders)); |
|
|
|
for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) { |
|
|
|
auto cur_cnode = cnode_ptr_list[i]; |
|
|
|
@@ -1049,7 +1072,17 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g |
|
|
|
orders.emplace_back(recv); |
|
|
|
orders.emplace_back(cur_cnode); |
|
|
|
} else { |
|
|
|
auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size(); |
|
|
|
size_t cur_stream_hcom_size = UINT32_MAX; |
|
|
|
size_t first_index = UINT32_MAX; |
|
|
|
size_t last_index = UINT32_MAX; |
|
|
|
for (const auto &item : hcom_index) { |
|
|
|
if (item.first == cur_hcom_stream_id) { |
|
|
|
cur_stream_hcom_size = item.second.size(); |
|
|
|
first_index = item.second.front(); |
|
|
|
last_index = item.second.back(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (cur_stream_hcom_size == 1) { |
|
|
|
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); |
|
|
|
orders.emplace_back(recv); |
|
|
|
@@ -1059,12 +1092,12 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g |
|
|
|
orders.emplace_back(send); |
|
|
|
} else { |
|
|
|
// current stream, first hcom:add recv op |
|
|
|
if (i == hcom_index.at(cur_hcom_stream_id).front()) { |
|
|
|
if (i == first_index) { |
|
|
|
auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); |
|
|
|
orders.emplace_back(recv); |
|
|
|
cur_event_id = resource_manager.ApplyNewEvent(); |
|
|
|
orders.emplace_back(cur_cnode); |
|
|
|
} else if (i == hcom_index.at(cur_hcom_stream_id).back()) { |
|
|
|
} else if (i == last_index) { |
|
|
|
// current stream, last hcom:add send op |
|
|
|
orders.emplace_back(cur_cnode); |
|
|
|
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); |
|
|
|
@@ -1080,19 +1113,19 @@ void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &g |
|
|
|
graph_ptr->set_execution_order(orders); |
|
|
|
} |
|
|
|
|
|
|
|
bool AscendStreamAssign::IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, |
|
|
|
size_t index) { |
|
|
|
bool AscendStreamAssign::IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index, |
|
|
|
const CNodePtr &node_ptr, size_t index) { |
|
|
|
MS_EXCEPTION_IF_NULL(node_ptr); |
|
|
|
auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr); |
|
|
|
auto it = hcom_index.find(cur_hcom_stream_id); |
|
|
|
if (it == hcom_index.end()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto iter = std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index); |
|
|
|
if (iter == hcom_index.at(cur_hcom_stream_id).end()) { |
|
|
|
return false; |
|
|
|
for (const auto &item : hcom_index) { |
|
|
|
if (item.first == cur_hcom_stream_id) { |
|
|
|
auto it = std::find(item.second.begin(), item.second.end(), index); |
|
|
|
if (it != item.second.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// section6 |
|
|
|
@@ -1110,7 +1143,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG |
|
|
|
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); |
|
|
|
it = cnodes.insert(it + 1, send_cnode_ptr); |
|
|
|
|
|
|
|
auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); |
|
|
|
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), false); |
|
|
|
if (target == cnodes.end()) { |
|
|
|
MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope() |
|
|
|
<< "] can't find target for insert recv op, no insert send/recv"; |
|
|
|
@@ -1441,7 +1474,8 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull<KernelGraphPtr> |
|
|
|
} |
|
|
|
|
|
|
|
vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::iterator begin, |
|
|
|
vector<CNodePtr>::iterator end, const CNodePtr &node) { |
|
|
|
vector<CNodePtr>::iterator end, const CNodePtr &node, |
|
|
|
bool exclude_hcom) { |
|
|
|
while (begin != end) { |
|
|
|
auto inputs = (*begin)->inputs(); |
|
|
|
for (size_t i = 1; i < inputs.size(); i++) { |
|
|
|
@@ -1451,16 +1485,22 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it |
|
|
|
auto new_inputs = cnode->inputs(); |
|
|
|
for (size_t j = 1; j < new_inputs.size(); j++) { |
|
|
|
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0); |
|
|
|
// find target node except hcom op. insert event for hcom in:InsertEventHcomDependCommonBak function |
|
|
|
// only insert one time |
|
|
|
if (node == new_real_input.first) { |
|
|
|
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]"; |
|
|
|
return begin; |
|
|
|
if (!(exclude_hcom && IsHcom(*begin))) { |
|
|
|
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]"; |
|
|
|
return begin; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto real_input = AnfAlgo::VisitKernel(input, 0); |
|
|
|
if (node == real_input.first) { |
|
|
|
MS_LOG(DEBUG) << "Find target op[" << (*begin)->DebugString() << "]"; |
|
|
|
return begin; |
|
|
|
if (!(exclude_hcom && IsHcom(*begin))) { |
|
|
|
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]"; |
|
|
|
return begin; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|