From bbbfaa244158388485132ac057abe3f961e657c3 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Sat, 20 Jun 2020 01:54:36 +0800 Subject: [PATCH] enable new control sink Signed-off-by: zhoufeng --- .../device/ascend/ascend_kernel_runtime.cc | 16 +- .../device/ascend/ascend_stream_assign.cc | 786 ++++++++---------- .../device/ascend/ascend_stream_assign.h | 105 ++- mindspore/ccsrc/device/kernel_adjust.cc | 143 ++-- mindspore/ccsrc/device/kernel_adjust.h | 9 +- mindspore/ccsrc/device/kernel_info.h | 2 +- mindspore/ccsrc/device/kernel_runtime.cc | 25 +- mindspore/ccsrc/kernel/rts/label_switch.cc | 1 - mindspore/ccsrc/pipeline/action.cc | 19 +- .../ccsrc/session/anf_runtime_algorithm.cc | 18 + .../ccsrc/session/anf_runtime_algorithm.h | 1 + .../ccsrc/session/ascend_control_parser.cc | 62 +- .../ccsrc/session/ascend_control_parser.h | 1 + mindspore/ccsrc/session/ascend_session.cc | 8 +- mindspore/ccsrc/session/kernel_graph.cc | 89 +- mindspore/ccsrc/session/kernel_graph.h | 5 +- mindspore/ccsrc/session/session_basic.cc | 18 +- mindspore/ccsrc/vm/transform.cc | 1 + mindspore/ccsrc/vm/transform.h | 2 +- .../tasksink/ascend_stream_assign_stub.cc | 7 +- 20 files changed, 644 insertions(+), 674 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index 5a89c80692..e874c99b09 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -340,15 +340,17 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { return true; } - AscendStreamAssign &stream_assign_instance = AscendStreamAssign::GetInstance(); + AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); + AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); // the streams' flag not HEAD_STREAM std::vector wait_active_stream_list; - stream_assign_instance.GetWaitStreams(&wait_active_stream_list); - auto force_copy_stream_list = stream_assign_instance.hcom_streams(); + assign_instance.GetWaitStreams(&wait_active_stream_list); + std::vector force_copy_stream_list; + assign_instance.GetHcomStreams(&force_copy_stream_list); - MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_assign_instance.GetTotalStreamNum() - << ", total event num:" << stream_assign_instance.total_event_num() + MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_manager.GetCurAllocStreamNum() + << ", total event num:" << assign_instance.total_event_num() << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << force_copy_stream_list.size(); @@ -356,8 +358,8 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { std::vector> empty_list; std::shared_ptr model = std::make_shared( task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, stream_assign_instance.GetTotalStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), - stream_assign_instance.total_event_num(), 0); + 0, 0, 0, 0, 0, stream_manager.GetCurAllocStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), + assign_instance.total_event_num(), 0); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); if (!ret.second) { diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc index d637cfb90a..125630fe22 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc @@ -33,238 +33,220 @@ namespace device { namespace ascend { const uint32_t kHcomMaxTask = 5; const uint32_t kCommonMaxTask = 350; -const uint32_t kIndependFirstStreamId = 1024; -bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) { - MS_EXCEPTION_IF_NULL(apply_kernel); - return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL; -} - -void AscendStreamAssign::ResetNew() { - total_common_stream_num_ = 0; - total_independ_stream_num_ = 0; - total_event_num_ = 0; - first_physic_id_ = UINT32_MAX; - first_logic_id_ = UINT32_MAX; - independent_id_ = kIndependFirstStreamId; - logic_to_independent_map_.clear(); - processed_logic_id_.clear(); - logic_to_physic_map_.clear(); - independent_before_physic_id_.clear(); - inner_parallel_streams_.clear(); - processed_parallel_streams_.clear(); - hcom_stream_list_.clear(); - need_first_active_streams_.clear(); -} - -void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - auto it = logic_to_independent_map_.find(processing_logic_id); - if (it == logic_to_independent_map_.end()) { - (void)logic_to_independent_map_.insert(std::make_pair(processing_logic_id, independent_id_)); - AnfAlgo::SetStreamId(independent_id_, cur_cnode_ptr.get()); - independent_id_++; - } else { - AnfAlgo::SetStreamId(it->second, cur_cnode_ptr.get()); - } - - if (first_physic_id_ == UINT32_MAX) { - auto res = std::find(independent_before_physic_id_.begin(), independent_before_physic_id_.end(), - AnfAlgo::GetStreamId(cur_cnode_ptr)); - if (res == independent_before_physic_id_.end()) { - independent_before_physic_id_.push_back(AnfAlgo::GetStreamId(cur_cnode_ptr)); - } - } -} - -void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, - uint32_t *cur_index, uint32_t *cur_stream_id) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(*pre_cnode_ptr); - bool over_max_hcom_task = (IsHcom(cur_cnode_ptr) && (*cur_index) % kHcomMaxTask == 0); - bool over_max_common_task = (!IsHcom(cur_cnode_ptr) && (*cur_index) % kCommonMaxTask == 0); - bool pre_common_cur_hcom = (IsHcom(cur_cnode_ptr) && !IsHcom(*pre_cnode_ptr)); - bool pre_hcom_cur_common = (!IsHcom(cur_cnode_ptr) && IsHcom(*pre_cnode_ptr)); - if (over_max_hcom_task || over_max_common_task || pre_common_cur_hcom || pre_hcom_cur_common) { - *cur_index = 0; - ++(*cur_stream_id); - } - - if (over_max_hcom_task || pre_common_cur_hcom) { - hcom_stream_list_.emplace_back(*cur_stream_id); - } - ++(*cur_index); - AnfAlgo::SetStreamId(*cur_stream_id, cur_cnode_ptr.get()); - *pre_cnode_ptr = cur_cnode_ptr; -} - -bool AscendStreamAssign::IsProcessed(uint32_t logic_id) { - auto it = std::find(processed_logic_id_.begin(), processed_logic_id_.end(), logic_id); - if (it == processed_logic_id_.end()) { - return false; - } - - return true; -} +void AscendStreamAssign::AssignStream(const shared_ptr &graph_ptr) { + if (IsTaskSink()) { + Reset(); + ReorderIndependentOrders(graph_ptr); + AssignAllNodesStream(graph_ptr); + UpdateAtomicAddrCleanStreamId(graph_ptr); + FindHcomParallelStreams(graph_ptr); + InsertStreamActive(graph_ptr); + InsertSendRecvForHcomParallel(graph_ptr); + InsertSendRecvForIndependent(graph_ptr); + UpdateEventId(graph_ptr); + GetNeedActiveStreams(graph_ptr); + graph_ptr->PrintGraphExecuteOrder(); + CheckStreamAssign(graph_ptr); + MS_LOG(INFO) << "after finish stream assign"; -void AscendStreamAssign::RecordIdMap(uint32_t logic_id, uint32_t physic_id) { - auto it = logic_to_physic_map_.find(logic_id); - if (it == logic_to_physic_map_.end()) { - MS_LOG(INFO) << "New logic_id[" << logic_id << "] to physic_id[" << physic_id << "]"; - (void)logic_to_physic_map_.insert(std::make_pair(logic_id, physic_id)); + // Get info for D Model + AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); + generator::IRModelUtil::GetInstance().set_event_num(total_event_num()); + generator::IRModelUtil::GetInstance().set_stream_num(stream_manager.GetCurAllocStreamNum()); + // Init to 1,temporarily + generator::IRModelUtil::GetInstance().set_batch_num(1); } } -void AscendStreamAssign::RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, - uint32_t cur_stream_id) { - AnfAlgo::SetStreamId(cur_stream_id, cur_cnode_ptr.get()); - RecordIdMap(cur_node_logic_id, cur_stream_id); - first_physic_id_ = cur_stream_id; - first_logic_id_ = cur_node_logic_id; -} +// section 0 +void AscendStreamAssign::CheckStreamAssign(const shared_ptr &graph_ptr) { + MS_EXCEPTION_IF_NULL(graph_ptr); + std::set streams; + uint32_t max_stream = 0; + uint32_t min_stream = kInvalidStreamId; + const std::vector &cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + if (stream_id == kInvalidStreamId) { + MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "] had not been assigned streams"; + } -uint32_t AscendStreamAssign::GetLogicId(const CNodePtr &cur_cnode_ptr) { - uint32_t logic_id = AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()); - if (logic_id == kInvalidDistincLabel) { - MS_LOG(EXCEPTION) << "node[" << cur_cnode_ptr->DebugString() << "] logic id is invalid"; + streams.emplace(stream_id); + if (stream_id > max_stream) { + max_stream = stream_id; + } + if (stream_id < min_stream) { + min_stream = stream_id; + } } - return logic_id; -} -void AscendStreamAssign::SetCommonStreamNum(uint32_t cur_stream_id) { - if (first_physic_id_ == UINT32_MAX) { - MS_LOG(INFO) << "cur common node size is zero"; - total_common_stream_num_ = 0; - } else { - total_common_stream_num_ = cur_stream_id + 1; + if (!streams.empty()) { + if (min_stream != 0) { + MS_LOG(EXCEPTION) << "before stream assign, assigned stream should start from 0, now is from " << min_stream; + } + if (max_stream != (streams.size() - 1)) { + MS_LOG(EXCEPTION) << "before stream assign, assigned stream should be consecutive"; + } } } +// section 1 void AscendStreamAssign::AssignAllNodesStream(const shared_ptr &graph_ptr) { MS_EXCEPTION_IF_NULL(graph_ptr); auto cnode_ptr_list = graph_ptr->execution_order(); CNodePtr pre_cnode_ptr = nullptr; uint32_t cur_index = 0; uint32_t cur_stream_id = 0; - uint32_t processing_logic_id = UINT32_MAX; + bool exit_independent = false; + AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - // get logic id - uint32_t cur_node_logic_id = GetLogicId(cur_cnode_ptr); + if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { + continue; + } if (IsIndependentNode(cur_cnode_ptr)) { - AssignIndependentStreamId(cur_cnode_ptr, cur_node_logic_id); + exit_independent = true; continue; } + // first common node, only exe one time if (pre_cnode_ptr == nullptr) { - RecordFirstCommonOp(cur_cnode_ptr, cur_node_logic_id, cur_stream_id); - processing_logic_id = cur_node_logic_id; + uint32_t cur_stream_num = stream_manager.GetCurAllocStreamNum(); + if (cur_stream_num == 0) { + cur_stream_id = stream_manager.ApplyNewStream(); + } else { + cur_stream_id = stream_manager.GetCurAllocStream(); + } ++cur_index; pre_cnode_ptr = cur_cnode_ptr; + AnfAlgo::SetStreamId(cur_stream_id, cur_cnode_ptr.get()); + if (IsHcom(cur_cnode_ptr)) { + hcom_stream_list_.emplace(cur_stream_id); + } continue; } - // 1.has been processed - if (IsProcessed(cur_node_logic_id)) { - continue; - } + AssignCommonStreamId(cur_cnode_ptr, &pre_cnode_ptr, &cur_index, &cur_stream_id); + } - if (cur_node_logic_id == processing_logic_id) { - AssignCommonStreamId(cur_cnode_ptr, &pre_cnode_ptr, &cur_index, &cur_stream_id); - } else { - // 1.find other same logic id - for (size_t j = i; j < cnode_ptr_list.size(); ++j) { - CNodePtr cnode_ptr = cnode_ptr_list[j]; - MS_EXCEPTION_IF_NULL(cnode_ptr); - uint32_t logic_id = AnfAlgo::GetStreamDistinctionLabel(cnode_ptr.get()); - if (logic_id == processing_logic_id) { - AssignCommonStreamId(cnode_ptr, &pre_cnode_ptr, &cur_index, &cur_stream_id); - } + if (exit_independent) { + uint32_t first_independent_stream_id = stream_manager.ApplyNewStream(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { + continue; + } + if (IsIndependentNode(cur_cnode_ptr)) { + AssignIndependentStreamId(cur_cnode_ptr); } - // 2.after deal: - processed_logic_id_.push_back(processing_logic_id); - cur_cnode_ptr = cnode_ptr_list[i]; - // 3. new stream - ++cur_stream_id; - AnfAlgo::SetStreamId(cur_stream_id, cur_cnode_ptr.get()); - cur_index = 1; - - pre_cnode_ptr = cur_cnode_ptr; - processing_logic_id = cur_node_logic_id; - RecordIdMap(processing_logic_id, cur_stream_id); } + MS_LOG(INFO) << "independent start from :" << first_independent_stream_id; } - SetCommonStreamNum(cur_stream_id); - total_independ_stream_num_ = independent_id_ - kIndependFirstStreamId; - MS_LOG(INFO) << "stream nums:common:" << total_common_stream_num_ << ",independ:" << total_independ_stream_num_; + MS_LOG(INFO) << "total stream nums:" << stream_manager.GetCurAllocStreamNum(); } -void AscendStreamAssign::TransLogicToPhysic(const vector &logic_ids, vector *physic_ids) { - for (auto &id : logic_ids) { - auto it = logic_to_physic_map_.find(id); - if (it != logic_to_physic_map_.end()) { - MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]"; - (*physic_ids).push_back(it->second); +void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr) { + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); + uint32_t cur_independent_id = stream_manager.GetCurAllocStream(); + auto it = independent_stream_map_.find(cur_independent_id); + if (it == independent_stream_map_.end()) { + AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); + independent_stream_map_.emplace(cur_independent_id, 1); + } else { + if (it->second < kCommonMaxTask) { + AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); + it->second++; } else { - MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id"; + cur_independent_id = stream_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); + independent_stream_map_.emplace(cur_independent_id, 1); } + } +} - auto it_independ = logic_to_independent_map_.find(id); - if (it_independ != logic_to_independent_map_.end()) { - MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]"; - (*physic_ids).push_back(it_independ->second); +bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { + MS_EXCEPTION_IF_NULL(node_ptr); + if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) { + return false; + } + + if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { + MS_LOG(INFO) << "GetNext should not be independent node"; + return false; + } + + uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr); + if (input_nums == 0) { + MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero"; + return true; + } + + const std::vector &inputs = node_ptr->inputs(); + for (size_t i = 1; i < inputs.size(); i++) { + if (!inputs[i]->isa()) { + return false; } } + MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node"; + return true; } -void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { - MS_LOG(INFO) << "start update outter active op[" << active_ptr->DebugString() << "] "; - MS_EXCEPTION_IF_NULL(active_ptr); - auto primitive = AnfAlgo::GetCNodePrimitive(active_ptr); - MS_EXCEPTION_IF_NULL(primitive); - vector active_logic_ids = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); - // out StreamAcitve active physic stream is not parallel now, if parallel, should deal here. - vector active_physic_ids; - TransLogicToPhysic(active_logic_ids, &active_physic_ids); - ValuePtr active_physic_value = MakeValue>(active_physic_ids); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr); -} +void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, + uint32_t *cur_index, uint32_t *cur_stream_id) { + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(pre_cnode_ptr); + MS_EXCEPTION_IF_NULL(*pre_cnode_ptr); + AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); + bool over_max_hcom_task = (IsHcom(cur_cnode_ptr) && (*cur_index) % kHcomMaxTask == 0); + bool over_max_common_task = (!IsHcom(cur_cnode_ptr) && (*cur_index) % kCommonMaxTask == 0); + bool pre_common_cur_hcom = (IsHcom(cur_cnode_ptr) && !IsHcom(*pre_cnode_ptr)); + bool pre_hcom_cur_common = (!IsHcom(cur_cnode_ptr) && IsHcom(*pre_cnode_ptr)); + if (over_max_hcom_task || over_max_common_task || pre_common_cur_hcom || pre_hcom_cur_common) { + *cur_index = 0; + *cur_stream_id = stream_manager.ApplyNewStream(); + } -void AscendStreamAssign::UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr) { - MS_LOG(INFO) << "start update switch op[" << switch_ptr->DebugString() << "]"; - MS_EXCEPTION_IF_NULL(switch_ptr); - MS_EXCEPTION_IF_NULL(active_ptr); - auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr); - MS_EXCEPTION_IF_NULL(primitive); - auto true_logic_id = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); - MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id - << "]"; - vector logic_ids{true_logic_id}; - vector physic_ids; - TransLogicToPhysic(logic_ids, &physic_ids); - if (physic_ids.empty()) { - MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id"; + ++(*cur_index); + AnfAlgo::SetStreamId(*cur_stream_id, cur_cnode_ptr.get()); + *pre_cnode_ptr = cur_cnode_ptr; + + // record ll hcom streams as hcom stream has different stream flag + if (IsHcom(cur_cnode_ptr)) { + auto it = std::find(hcom_stream_list_.begin(), hcom_stream_list_.end(), *cur_stream_id); + if (it == hcom_stream_list_.end()) { + MS_LOG(INFO) << "hcom stream id:" << *cur_stream_id; + hcom_stream_list_.emplace(*cur_stream_id); + } } - ValuePtr true_index = MakeValue(physic_ids[0]); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, true_index, switch_ptr); +} - MS_LOG(INFO) << "start update StreamActive op[" << active_ptr->DebugString() << "]"; - AnfAlgo::SetStreamId(physic_ids[0], active_ptr.get()); - vector active_ids; - for (size_t i = 0; i < physic_ids.size(); i++) { - if (i == 0) { - MS_LOG(INFO) << "StreamActive op self stream id[" << physic_ids[i] << "]"; - } else { - MS_LOG(INFO) << "StreamActive op active stream id[" << physic_ids[i] << "]"; - active_ids.emplace_back(physic_ids[i]); +// section 2: +void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const shared_ptr &graph_ptr) { + MS_LOG(INFO) << "start"; + MS_EXCEPTION_IF_NULL(graph_ptr); + const std::vector &cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + // update AtomicAddrClean stream same witch the next node + if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) { + MS_LOG(INFO) << "update AtomicAddrClean stream id from[" << AnfAlgo::GetStreamId(cnode_ptr_list[i - 1]) + << "] to [" << AnfAlgo::GetStreamId(cur_cnode_ptr) << "]"; + AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get()); } } - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_ids), active_ptr); + MS_LOG(INFO) << "end"; } -void AscendStreamAssign::FindAllReduceParallel(const shared_ptr &graph_ptr) { +// section 3 +void AscendStreamAssign::FindHcomParallelStreams(const shared_ptr &graph_ptr) { MS_EXCEPTION_IF_NULL(graph_ptr); CNodePtr cur_cnode_ptr = nullptr; CNodePtr pre_cnode_ptr = nullptr; @@ -280,9 +262,9 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr{pre_stream_id, cur_stream_id}); } @@ -291,6 +273,138 @@ void AscendStreamAssign::FindAllReduceParallel(const shared_ptr &graph_ptr, + const CNodePtr &switch_ptr, const vector &independent_stream, + vector *orders) { + MS_EXCEPTION_IF_NULL(orders); + orders->emplace_back(switch_ptr); + auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr); + MS_EXCEPTION_IF_NULL(primitive); + auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); + if (value_ptr == nullptr) { + return; + } + + auto need_active = GetValue(value_ptr); + if (!need_active) { + return; + } + + MS_LOG(INFO) << "start update switch op[" << switch_ptr->DebugString() << "]"; + MS_EXCEPTION_IF_NULL(switch_ptr); + auto true_stream_id = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); + MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_stream_id + << "]"; + + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); + MS_LOG(INFO) << "start update StreamActive op[" << active_ptr->DebugString() << "]"; + AnfAlgo::SetStreamId(true_stream_id, active_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(independent_stream), active_ptr); + independent_stream_activated_ = true; + + // update processed stream + for (auto &item : independent_stream) { + processed_streams_.emplace(item); + } + + orders->emplace_back(active_ptr); +} // namespace ascend + +void AscendStreamAssign::InsertStreamActive(const std::shared_ptr &graph_ptr) { + MS_LOG(INFO) << "start"; + MS_EXCEPTION_IF_NULL(graph_ptr); + std::vector update_cnode_list; + CNodePtr cur_cnode_ptr = nullptr; + CNodePtr pre_cnode_ptr = nullptr; + uint32_t pre_stream_id = UINT32_MAX; + std::vector independent_stream; + MS_LOG(INFO) << "independent stream size:" << independent_stream_map_.size(); + for (auto item : independent_stream_map_) { + independent_stream.emplace_back(item.first); + } + + bool independent_flag = !(independent_stream.empty()); + + const std::vector &cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + if (IsIndependentNode(cur_cnode_ptr)) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + + bool inner_active = false; + if (pre_cnode_ptr != nullptr) { + inner_active = pre_stream_id != cur_stream_id && AnfAlgo::GetCNodeName(pre_cnode_ptr) != kStreamSwitchOpName && + AnfAlgo::GetCNodeName(pre_cnode_ptr) != kSendOpName; + } + + bool processed = IsProcessedStream(cur_stream_id); + // 1)inner stream assign, need insert active op + if (inner_active && !processed) { + MS_LOG(INFO) << "Inner insert active op, self stream id[" << pre_stream_id << "]"; + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); + // 1.set stream id + AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); + // 2.set active stream ids + std::vector active_index_list; + GetParallelStream(cur_stream_id, pre_stream_id, &active_index_list); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); + update_cnode_list.emplace_back(active_ptr); + } + + if (independent_flag && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) { + MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; + UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, independent_stream, &update_cnode_list); + } else { + update_cnode_list.emplace_back(cur_cnode_ptr); + } + + processed_streams_.emplace(cur_stream_id); + pre_stream_id = cur_stream_id; + pre_cnode_ptr = cur_cnode_ptr; + } + graph_ptr->set_execution_order(update_cnode_list); + MS_LOG(INFO) << "end"; +} + +bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { + auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id); + if (it != processed_streams_.end()) { + return true; + } + return false; +} + +void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, + vector *parallel_streams) { + MS_EXCEPTION_IF_NULL(parallel_streams); + for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { + const auto &cur_parallel_streams = inner_parallel_streams_[i]; + auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); + if (it != cur_parallel_streams.end()) { + MS_LOG(INFO) << "stream id:" << cur_stream_id << " is parallel stream"; + for (size_t j = 0; j < cur_parallel_streams.size(); j++) { + if (cur_parallel_streams[j] == stream_acitve_id) { + MS_LOG(INFO) << "one of parallel stream id" << cur_parallel_streams[j] + << "is same with streamacvite stream id" << stream_acitve_id; + continue; + } + (*parallel_streams).emplace_back(cur_parallel_streams[j]); + processed_streams_.emplace(cur_parallel_streams[j]); + } + return; + } + } + + processed_streams_.emplace(cur_stream_id); + (*parallel_streams).push_back(cur_stream_id); +} + +// section5 void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr &graph_ptr) { MS_LOG(INFO) << "start"; MS_EXCEPTION_IF_NULL(graph_ptr); @@ -299,7 +413,7 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr orders; for (size_t i = 0; i < cnode_ptr_list.size(); i++) { auto cur_cnode = cnode_ptr_list[i]; - if (IsHcom(cur_cnode)) { + if (IsFusionHcom(cur_cnode)) { fusion_hcom_index.emplace_back(i); } } @@ -310,7 +424,7 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr *parallel_streams) { - for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { - auto cur_parallel_streams = inner_parallel_streams_[i]; - auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); - if (it != cur_parallel_streams.end()) { - MS_LOG(INFO) << "stream id:" << cur_stream_id << " is parallel stream"; - for (size_t j = 0; j < cur_parallel_streams.size(); j++) { - if (cur_parallel_streams[j] == stream_acitve_id) { - MS_LOG(INFO) << "one of parallel stream id" << cur_parallel_streams[j] - << "is same with streamacvite stream id" << stream_acitve_id; - continue; - } - (*parallel_streams).emplace_back(cur_parallel_streams[j]); - } - - // record processed parallel streams - (void)std::copy((*parallel_streams).begin(), (*parallel_streams).end(), - std::back_inserter(processed_parallel_streams_)); - return; - } - } - - (*parallel_streams).push_back(cur_stream_id); -} - -void AscendStreamAssign::InsertActiveNew(const std::shared_ptr &graph_ptr) { - MS_LOG(INFO) << "start"; - MS_EXCEPTION_IF_NULL(graph_ptr); - std::vector update_cnode_list; - CNodePtr cur_cnode_ptr = nullptr; - CNodePtr pre_cnode_ptr = nullptr; - uint32_t pre_stream_id = UINT32_MAX; - - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - if (cur_stream_id >= kIndependFirstStreamId) { - update_cnode_list.emplace_back(cur_cnode_ptr); - continue; - } - - bool inner_active = pre_stream_id != cur_stream_id && pre_stream_id < cur_stream_id && - AnfAlgo::GetCNodeName(pre_cnode_ptr) != kStreamSwitchOpName && - AnfAlgo::GetCNodeName(pre_cnode_ptr) != kStreamActiveOpName && - AnfAlgo::GetCNodeName(pre_cnode_ptr) != kSendOpName; - bool processed = IsProcessedParallelStream(cur_stream_id); - // 1)inner stream assign, need insert active op - if (inner_active && !processed) { - MS_LOG(INFO) << "Inner insert active op, self stream id[" << pre_stream_id << "]"; - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); - update_cnode_list.emplace_back(active_ptr); - // 1.set stream id - AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); - // 2.set active stream ids - std::vector active_index_list; - GetParallelStream(cur_stream_id, pre_stream_id, &active_index_list); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); - } - // inner_active is not a if/else relationship with the next if/else. such as:StreamActive(S7)-->StreamActive(S8) - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName && - AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { - // 2)outter stream assign, update active op - update_cnode_list.emplace_back(cur_cnode_ptr); - UpdateStreamActive(cur_cnode_ptr); - } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { - // 3)update switch op - MS_LOG(INFO) << "Insert active op after switch"; - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); - update_cnode_list.emplace_back(cur_cnode_ptr); - update_cnode_list.emplace_back(active_ptr); - UpdateStreamSwitch(cur_cnode_ptr, active_ptr); - } else { - update_cnode_list.emplace_back(cur_cnode_ptr); - } - - pre_stream_id = cur_stream_id; - pre_cnode_ptr = cur_cnode_ptr; - } - graph_ptr->set_execution_order(update_cnode_list); - MS_LOG(INFO) << "end"; -} - void AscendStreamAssign::UpdateEventId(const shared_ptr &graph_ptr) { MS_LOG(INFO) << "start"; MS_EXCEPTION_IF_NULL(graph_ptr); @@ -514,64 +540,11 @@ void AscendStreamAssign::UpdateEventId(const shared_ptr &g } } -void AscendStreamAssign::UpdateStreamId(const shared_ptr &graph_ptr) { - MS_LOG(INFO) << "start"; - MS_EXCEPTION_IF_NULL(graph_ptr); - CNodePtr cur_cnode_ptr = nullptr; - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - if (cur_stream_id < kIndependFirstStreamId) { - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName) { - auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(primitive); - vector active_ids = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); - for (size_t j = 0; j < active_ids.size(); j++) { - if (active_ids[j] >= kIndependFirstStreamId) { - active_ids[j] = active_ids[j] - kIndependFirstStreamId + total_common_stream_num_; - } - } - ValuePtr active_value = MakeValue>(active_ids); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_value, cur_cnode_ptr); - } - } else { - uint32_t update_id = cur_stream_id - kIndependFirstStreamId + total_common_stream_num_; - AnfAlgo::SetStreamId(update_id, cur_cnode_ptr.get()); - } - - // update AtomicAddrClean stream same witch the next node - if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == "AtomicAddrClean") { - MS_LOG(INFO) << "update AtomicAddrClean stream id from[" << AnfAlgo::GetStreamId(cnode_ptr_list[i - 1]) - << "] to [" << AnfAlgo::GetStreamId(cur_cnode_ptr) << "]"; - AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get()); - } - } - - // update logic_to_independent_map_ - for (auto &indep : logic_to_independent_map_) { - if (indep.second >= kIndependFirstStreamId) { - indep.second = indep.second - kIndependFirstStreamId + total_common_stream_num_; - } - } - - // update independent_before_physic_id_ - for (auto &id : independent_before_physic_id_) { - if (id >= kIndependFirstStreamId) { - id = id - kIndependFirstStreamId + total_common_stream_num_; - } - } - - // update independent_id_ - independent_id_ = independent_id_ - kIndependFirstStreamId + total_common_stream_num_; - MS_LOG(INFO) << "end"; -} - void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr &graph_ptr) { MS_EXCEPTION_IF_NULL(graph_ptr); CNodePtr cur_cnode_ptr = nullptr; auto cnode_ptr_list = graph_ptr->execution_order(); + // 1)stream witch kStreamNeedActivedFirst attr should be actived; for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); @@ -589,29 +562,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr &graph_ptr) { - if (IsTaskSink()) { - ResetNew(); - ReorderIndependentOrders(graph_ptr); - AssignAllNodesStream(graph_ptr); - FindAllReduceParallel(graph_ptr); - InsertActiveNew(graph_ptr); - InsertSendRecvForHcomParallel(graph_ptr); - InsertSendRecvForIndependent(graph_ptr); - UpdateStreamId(graph_ptr); - UpdateEventId(graph_ptr); - GetNeedActiveStreams(graph_ptr); - - MS_LOG(INFO) << "after finish stream assign"; - graph_ptr->PrintGraphExecuteOrder(); + // 2)first stream 0 should be actived first; + need_first_active_streams_.emplace_back(0); - // Get info for D Model - generator::IRModelUtil::GetInstance().set_event_num(total_event_num()); - generator::IRModelUtil::GetInstance().set_stream_num(total_common_stream_num() + total_independ_stream_num()); - // Init to 1,temporarily - generator::IRModelUtil::GetInstance().set_batch_num(1); + // 3)independent stream:if has not been activate, push to need active vector + if (!independent_stream_activated_) { + for (auto &item : independent_stream_map_) { + need_first_active_streams_.emplace_back(item.first); + } } } @@ -722,33 +681,6 @@ void AscendStreamAssign::InsertSendRecvForIndependent(const shared_ptrfullname_with_scope() << " is independent, as inputs nums is zero"; - return true; - } - - auto inputs = node_ptr->inputs(); - for (size_t i = 1; i < inputs.size(); i++) { - if (!inputs[i]->isa()) { - return false; - } - } - MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node"; - return true; -} - bool AscendStreamAssign::IsTaskSink() { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); @@ -762,56 +694,54 @@ bool AscendStreamAssign::IsTaskSink() { } void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { - if (total_common_stream_num_ == 0) { + MS_EXCEPTION_IF_NULL(wait_active_stream_list); + AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); + uint32_t total_stream_num = stream_manager.GetCurAllocStreamNum(); + if (total_stream_num == 0) { MS_LOG(INFO) << "total_common_stream_num is zero"; return; } // common stream:active first common stream - MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]"; - for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) { + for (uint32_t i = 0; i < total_stream_num; i++) { auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); if (it == need_first_active_streams_.end()) { MS_LOG(INFO) << "wait common stream id = " << i; (*wait_active_stream_list).push_back(i); } } +} - // all independ stream id before first physical stream id should be actived - auto it = logic_to_independent_map_.find(first_logic_id_); - if (it != logic_to_independent_map_.end()) { - uint32_t independent_id = it->second; - auto res = std::find(independent_before_physic_id_.begin(), independent_before_physic_id_.end(), independent_id); - if (res == independent_before_physic_id_.end()) { - // first physical to independ id may be not in independent_before_physic_id_ - independent_before_physic_id_.push_back(independent_id); - } - MS_LOG(INFO) << "active independent id[" << independent_id << "]"; +bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) { + MS_EXCEPTION_IF_NULL(apply_kernel); + return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL; +} + +bool AscendStreamAssign::IsFusionHcom(const CNodePtr &cur_cnode_ptr) { + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + bool is_hcom = IsHcom(cur_cnode_ptr); + if (!is_hcom) { + return false; } - uint32_t max_before_physic = 0; - for (size_t i = 0; i < independent_before_physic_id_.size(); i++) { - if (independent_before_physic_id_[i] > max_before_physic) { - max_before_physic = independent_before_physic_id_[i]; - } - MS_LOG(INFO) << "independent id[" << independent_before_physic_id_[i] << "] before first physic is active"; + if (!AnfAlgo::HasNodeAttr(kAttrFusion, cur_cnode_ptr)) { + return false; } - for (uint32_t i = 0; i < total_independ_stream_num_; i++) { - if (i + total_common_stream_num_ <= max_before_physic) { - continue; - } - // all wait streams should not in need_first_active_streams_ - auto iter = - std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i + total_common_stream_num_); - if (iter == need_first_active_streams_.end()) { - MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; - (*wait_active_stream_list).push_back(i + total_common_stream_num_); - } + if (AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrFusion) == 0) { + return false; + } + + return true; +} + +void AscendStreamAssign::GetHcomStreams(std::vector *streams) { + MS_EXCEPTION_IF_NULL(streams); + for (const auto &stream : hcom_stream_list_) { + (*streams).emplace_back(stream); } } -uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; } void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr &graph_ptr) { MS_EXCEPTION_IF_NULL(graph_ptr); CNodePtr cur_cnode_ptr = nullptr; @@ -829,24 +759,19 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptrset_execution_order(exe_orders); + if (others.empty() || independents.empty()) { + MS_LOG(INFO) << "independent or others is empty, no need reorder"; return; } - if (independents.empty()) { - std::copy(others.begin(), others.end(), std::back_inserter(exe_orders)); - graph_ptr->set_execution_order(exe_orders); - return; - } - std::vector processed; + + std::set processed; for (size_t i = 0; i < others.size(); i++) { auto begin = others.begin() + i; auto end = begin + 1; bool flag = false; for (size_t j = 0; j < independents.size(); j++) { auto cur_independent = independents[j]; - auto it = std::find(processed.begin(), processed.end(), cur_independent); + auto it = std::find(processed.begin(), processed.end(), cur_independent.get()); if (it != processed.end()) { continue; } @@ -855,7 +780,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptrset_execution_order(exe_orders); } + +void AscendStreamAssign::Reset() { + total_event_num_ = 0; + independent_stream_activated_ = false; + independent_stream_map_.clear(); + processed_streams_.clear(); + hcom_stream_list_.clear(); + need_first_active_streams_.clear(); + inner_parallel_streams_.clear(); +} } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h index 4bb55a3d21..bb918cfc79 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -36,6 +38,36 @@ using std::shared_ptr; using std::unordered_map; using std::unordered_set; using std::vector; +using CnodeKey = void *; +const uint32_t kInvalidStreamId = UINT32_MAX; +class AscendStreamMng { + public: + static AscendStreamMng &GetInstance() { + static AscendStreamMng instance; + return instance; + } + + void Reset() { + cur_stream_id = 0; + cur_stream_num = 0; + } + uint32_t ApplyNewStream() { + if (!cur_stream_num) { + cur_stream_num++; + return cur_stream_id; + } + cur_stream_num++; + cur_stream_id++; + return cur_stream_id; + } + + uint32_t GetCurAllocStream() { return cur_stream_id; } + uint32_t GetCurAllocStreamNum() { return cur_stream_num; } + + private: + uint32_t cur_stream_num{0}; + uint32_t cur_stream_id{0}; +}; class AscendStreamAssign { public: @@ -47,22 +79,11 @@ class AscendStreamAssign { AscendStreamAssign(const AscendStreamAssign &) = delete; AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; - uint32_t GetTotalStreamNum() const; - // new stream policy - uint32_t total_common_stream_num() const { return total_common_stream_num_; } - uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } uint32_t total_event_num() const { return total_event_num_; } + void GetHcomStreams(std::vector *streams); - void InsertActiveNew(const std::shared_ptr &graph_ptr); - void AssignAllNodesStream(const std::shared_ptr &graph_ptr); - void ResetNew(); - void AssignStreamNew(const std::shared_ptr &graph_ptr); - bool IsIndependentNode(const CNodePtr &node_ptr); - const std::unordered_map &logic_to_independent_map() { return logic_to_independent_map_; } - const std::unordered_map &logic_to_physic_map() { return logic_to_physic_map_; } - const std::vector> &inner_parallel_streams() { return inner_parallel_streams_; } + void AssignStream(const std::shared_ptr &graph_ptr); void GetWaitStreams(vector *wait_active_stream_list); - const std::vector &hcom_streams() { return hcom_stream_list_; } CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id); CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, @@ -71,49 +92,41 @@ class AscendStreamAssign { private: AscendStreamAssign() = default; ~AscendStreamAssign() = default; - - vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, - const CNodePtr &node); - - bool IsHcom(const CNodePtr &apply_kernel); - bool IsProcessed(uint32_t logic_id); - void TransLogicToPhysic(const vector &logic_ids, vector *physic_ids); + void Reset(); + void CheckStreamAssign(const std::shared_ptr &graph_ptr); + void AssignAllNodesStream(const std::shared_ptr &graph_ptr); void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, uint32_t *cur_stream_id); - void RecordIdMap(uint32_t logic_id, uint32_t physic_id); - void UpdateStreamActive(const CNodePtr &active_ptr); - void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr); - bool IsTaskSink(); - void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id); - void UpdateStreamId(const std::shared_ptr &graph_ptr); - void UpdateEventId(const std::shared_ptr &graph_ptr); - void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); - uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr); - void SetCommonStreamNum(uint32_t cur_stream_id); - void FindAllReduceParallel(const std::shared_ptr &graph_ptr); - bool IsProcessedParallelStream(uint32_t stream_id); - void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); + void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); + void UpdateAtomicAddrCleanStreamId(const std::shared_ptr &graph_ptr); + void FindHcomParallelStreams(const std::shared_ptr &graph_ptr); + void InsertStreamActive(const std::shared_ptr &graph_ptr); + void UpdateStreamSwitch(const std::shared_ptr &graph_ptr, const CNodePtr &switch_ptr, + const vector &independent_stream, vector *orders); void InsertSendRecvForIndependent(const std::shared_ptr &graph_ptr); void InsertSendRecvForHcomParallel(const std::shared_ptr &graph_ptr); void InsertSendRecvForDiffHcom(const shared_ptr &graph_ptr); + void UpdateEventId(const std::shared_ptr &graph_ptr); void GetNeedActiveStreams(const std::shared_ptr &graph_ptr); void ReorderIndependentOrders(const std::shared_ptr &graph_ptr); - uint32_t total_common_stream_num_{0}; - uint32_t total_independ_stream_num_{0}; - uint32_t total_event_num_{0}; + bool IsTaskSink(); + bool IsFusionHcom(const CNodePtr &cur_cnode_ptr); + bool IsHcom(const CNodePtr &cur_cnode_ptr); + bool IsIndependentNode(const CNodePtr &node_ptr); + bool IsProcessedStream(uint32_t stream_id); + vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, + const CNodePtr &node); + void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); - uint32_t first_physic_id_{UINT32_MAX}; - uint32_t first_logic_id_{UINT32_MAX}; - uint32_t independent_id_{UINT32_MAX}; - vector processed_logic_id_{}; - std::unordered_map logic_to_physic_map_{}; // key:logic id, value: first physic id - std::unordered_map logic_to_independent_map_{}; // key:logic id, value: dependent id - std::vector independent_before_physic_id_{}; // record independent id before first physic id - std::vector> inner_parallel_streams_{}; - std::vector processed_parallel_streams_{}; - std::vector hcom_stream_list_{}; + uint32_t total_event_num_{0}; + bool independent_stream_activated_{false}; + std::map independent_stream_map_{}; + std::set processed_streams_{}; + std::set hcom_stream_list_{}; std::vector need_first_active_streams_{}; + std::vector> inner_parallel_streams_{}; + // new policy end }; } // namespace ascend diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index b4cecb6cd7..93007764af 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -37,24 +37,6 @@ namespace mindspore { namespace device { using device::ascend::ProfilingUtils; -void KernelAdjust::Reorder(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - const std::vector &origin_cnode_list = kernel_graph->execution_order(); - std::vector momentum_list; - std::vector other_list; - for (const auto &cnode : origin_cnode_list) { - if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOperatorSet.end()) { - momentum_list.emplace_back(cnode); - } else { - other_list.emplace_back(cnode); - } - } - std::vector new_order_list; - new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end()); - new_order_list.insert(new_order_list.end(), momentum_list.begin(), momentum_list.end()); - kernel_graph->set_execution_order(new_order_list); -} - void KernelAdjust::ReorderGetNext(const std::shared_ptr &kernel_graph_ptr) { MS_EXCEPTION_IF_NULL(kernel_graph_ptr); const std::vector &origin_cnode_list = kernel_graph_ptr->execution_order(); @@ -80,23 +62,6 @@ bool KernelAdjust::NeedInsertSwitch() { ConfigManager::GetInstance().iter_num() > 1); } -uint32_t KernelAdjust::FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - auto cnode_ptr_list = kernel_graph_ptr->execution_order(); - CNodePtr cur_cnode_ptr = nullptr; - uint32_t label = kInvalidDistincLabel; - for (uint32_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { - label = AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()); - break; - } - } - - return label; -} - CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id) { MS_EXCEPTION_IF_NULL(graph_ptr); @@ -138,6 +103,8 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &kernel_graph_ptr) { + device::ascend::AscendStreamMng &stream_manager = device::ascend::AscendStreamMng::GetInstance(); + stream_manager.Reset(); if (!NeedInsertSwitch()) { return; } @@ -166,68 +133,62 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr if (orders.empty()) { MS_LOG(EXCEPTION) << "graph execution order is empty"; } - uint32_t first_cnode_stream_label = AnfAlgo::GetStreamDistinctionLabel(orders[0].get()); std::vector exec_order; - CNodePtr first_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(first_stream_switch_app); - AnfAlgo::SetStreamDistinctionLabel(kFirstStreamSwitchLabel, first_stream_switch_app.get()); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(kGetNextLabel), first_stream_switch_app); - - CNodePtr second_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(second_stream_switch_app); - AnfAlgo::SetStreamDistinctionLabel(kSecondStreamSwitchLabel, second_stream_switch_app.get()); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(first_cnode_stream_label), second_stream_switch_app); - // add attr "stream_need_active" - AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), second_stream_switch_app); - - CNodePtr first_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(first_stream_active_app); - AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, first_stream_active_app.get()); - std::vector first_active_streams = {kFirstStreamSwitchLabel}; - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(first_active_streams), - first_stream_active_app); - - CNodePtr second_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(second_stream_active_app); - // specific deal for common ctrl stream policy - uint32_t first_common_stream_switch_label = FindFirstStreamSwitchLabel(kernel_graph_ptr); - if (first_common_stream_switch_label == kInvalidDistincLabel) { - AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, second_stream_active_app.get()); - } else { - AnfAlgo::SetStreamDistinctionLabel(first_common_stream_switch_label, second_stream_active_app.get()); - } - - std::vector second_active_streams = {kSecondStreamSwitchLabel}; - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(second_active_streams), - second_stream_active_app); - CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(assign_add_one); - AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, assign_add_one.get()); - - CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); - AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, send.get()); - CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); - AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, recv.get()); + // getnext loop process + // getnext loop stream switch op + CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(getnext_switch_app); + uint32_t getnext_switch_stream_id = stream_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); + exec_order.push_back(getnext_switch_app); - // reorder graph orders - exec_order.push_back(first_stream_switch_app); + // getnext op + uint32_t getnext_stream_id = stream_manager.ApplyNewStream(); size_t i = 0; for (; i < orders.size(); i++) { auto node = orders[i]; exec_order.push_back(node); - AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, exec_order[exec_order.size() - 1].get()); + AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { break; } } + // update getnext loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); + + // getnext loop send + CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); + AnfAlgo::SetStreamId(getnext_stream_id, send.get()); exec_order.push_back(send); - exec_order.push_back(second_stream_switch_app); + + // fpbp loop process + // fpbp loop stream switch + CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(fpbp_switch_app); + uint32_t fpbp_switch_stream_id = stream_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), fpbp_switch_app); + exec_order.push_back(fpbp_switch_app); + + // fpbp loop recv + CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); + uint32_t fpbp_stream_id = stream_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(fpbp_stream_id, recv.get()); exec_order.push_back(recv); + + // update fpbp loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(fpbp_stream_id), fpbp_switch_app); + + // fpbp loop AssignAdd + CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(assign_add_one); + AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get()); exec_order.push_back(assign_add_one); + // fpbp memcpy std::vector memcpy_list; std::vector before_list; std::vector after_list; @@ -244,12 +205,28 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr before_list.emplace_back(cur_cnode); } } - (void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order)); (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); - exec_order.push_back(first_stream_active_app); + + // stream active to activate getnext loop + CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(getnext_active_app); + std::vector getnext_active_streams = {getnext_switch_stream_id}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(getnext_active_streams), + getnext_active_app); + exec_order.push_back(getnext_active_app); + + // fpbp loop other ops (void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order)); - exec_order.push_back(second_stream_active_app); + + // stream active to activate fpbp loop + CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(fpbp_active_app); + // specific deal for common ctrl stream policy + std::vector fpbp_active_streams = {fpbp_switch_stream_id}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(fpbp_active_streams), fpbp_active_app); + exec_order.push_back(fpbp_active_app); + kernel_graph_ptr->set_execution_order(exec_order); } diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h index feb97712be..1a7436b396 100644 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ b/mindspore/ccsrc/device/kernel_adjust.h @@ -39,9 +39,9 @@ constexpr auto kZeroParamName = "zero"; constexpr auto kOneParamName = "one"; constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; -const uint32_t kFirstStreamSwitchLabel = kInvalidDistincLabel - 1; -const uint32_t kGetNextLabel = kInvalidDistincLabel - 2; -const uint32_t kSecondStreamSwitchLabel = kInvalidDistincLabel - 3; +const uint32_t kFirstStreamSwitchLabel = 0; +const uint32_t kGetNextLabel = 1; +const uint32_t kSecondStreamSwitchLabel = 2; const uint32_t kInvalidEventId = UINT32_MAX; const uint32_t kFirstEventId = kInvalidEventId / 2; namespace device { @@ -51,7 +51,7 @@ class KernelAdjust { static KernelAdjust instance; return instance; } - void Reorder(const std::shared_ptr &kernel_graph); + void InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr); bool StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr); void Profiling(NotNull kernel_graph_ptr); @@ -65,7 +65,6 @@ class KernelAdjust { void ReorderGetNext(const std::shared_ptr &kernel_graph_ptr); CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); - uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr); void CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, std::map *switch_loop_input); CNodePtr CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, diff --git a/mindspore/ccsrc/device/kernel_info.h b/mindspore/ccsrc/device/kernel_info.h index 33ddda83c9..84cfaa0fa3 100644 --- a/mindspore/ccsrc/device/kernel_info.h +++ b/mindspore/ccsrc/device/kernel_info.h @@ -35,7 +35,7 @@ class KernelInfo { select_kernel_build_info_ = nullptr; output_address_list_ = {}; workspace_address_list_ = {}; - stream_id_ = 0; + stream_id_ = UINT32_MAX; stream_distinction_label_ = kInvalidDistincLabel; graph_id_ = kInvalidGraphId; } diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index 043b0b8c83..aae21aac72 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -283,18 +283,37 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(mem_manager_); auto graph_inputs = graph->inputs(); auto graph_valid_input = graph->valid_inputs(); - for (size_t i = 0; i < graph_inputs.size(); i++) { + std::vector need_alloc_nodes; + for (size_t i = 0; i < graph_inputs.size(); ++i) { auto item = graph_inputs[i]; MS_EXCEPTION_IF_NULL(item); - if (!item->isa()) { + if (i < graph_valid_input.size() && !graph_valid_input[i]) { continue; } - if (i < graph_valid_input.size() && !graph_valid_input[i]) { + + if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { + auto outs = AnfAlgo::GetAllOutput(item); + for (auto &out : outs) { + MS_EXCEPTION_IF_NULL(out); + if (!out->isa()) { + continue; + } + if (NodeOutputDeviceAddressExist(out, 0)) { + continue; + } + need_alloc_nodes.push_back(out); + } + } + if (!item->isa()) { continue; } if (NodeOutputDeviceAddressExist(item, 0)) { continue; } + need_alloc_nodes.push_back(item); + } + + for (auto &item : need_alloc_nodes) { auto output_size = AnfAlgo::GetOutputTensorNum(item); for (size_t index = 0; index < output_size; index++) { TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); diff --git a/mindspore/ccsrc/kernel/rts/label_switch.cc b/mindspore/ccsrc/kernel/rts/label_switch.cc index 256ab70710..d84407a930 100644 --- a/mindspore/ccsrc/kernel/rts/label_switch.cc +++ b/mindspore/ccsrc/kernel/rts/label_switch.cc @@ -75,7 +75,6 @@ std::vector LabelSwitchKernel::GenTask(const std::vector> LabelSwitchDesc::GetKernelInfo() { std::vector> label_switch_build_info{}; - vector input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT}; vector input_type{kNumberTypeUInt32, kNumberTypeBool}; if (input_format.size() != input_type.size()) { diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index f127305d1b..01098a3e14 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -281,8 +281,12 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } -static bool IsCtrlSink() { +static bool IsCtrlSink(const FuncGraphPtr &graph) { auto ms_ctx = MsContext::GetInstance(); + if (ms_ctx->execution_mode() != kGraphMode) { + return false; + } + std::string device_target = ms_ctx->device_target(); if (device_target != kAscendDevice) { return false; @@ -292,12 +296,7 @@ static bool IsCtrlSink() { return false; } - const char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK"); - if (enable_ctrl_sink == nullptr) { - return false; - } - std::string enable_ctrl_sink_str(enable_ctrl_sink); - if (enable_ctrl_sink_str == "0") { + if (graph != nullptr && CompileGraphs::ContainMixedTarget(graph)) { return false; } @@ -310,7 +309,7 @@ bool TaskEmitAction(const ResourcePtr &res) { } FuncGraphPtr func_graph = res->func_graph(); auto bc_ptr = res->results()[kBackend].cast(); - if (IsCtrlSink()) { + if (IsCtrlSink(func_graph)) { res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); return true; } @@ -322,7 +321,7 @@ bool TaskEmitAction(const ResourcePtr &res) { std::shared_ptr compile = std::make_shared(bc_ptr, cut_list); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (compile->ContainMixedTarget(func_graph)) { + if (CompileGraphs::ContainMixedTarget(func_graph)) { bc_ptr->set_is_multi_graph_sink(false); context_ptr->set_loop_sink_flag(false); } else if (context_ptr->execution_mode() != kPynativeMode) { @@ -340,7 +339,7 @@ bool ExecuteAction(const ResourcePtr &res) { MS_LOG(EXCEPTION) << "Execute args error"; } - if (IsCtrlSink()) { + if (IsCtrlSink(nullptr)) { if (!res->results()[kOutput].is()) { MS_LOG(EXCEPTION) << "Execute args error"; } diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 836110b8a4..1ec11d50db 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -996,5 +996,23 @@ bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { } return shape.size() == kShape1dDims && shape[0] == 1; } + +void AnfRuntimeAlgorithm::ReorderExecList(NotNull *> node_list) { + std::vector all_opt_list; + std::vector non_opt_list; + + for (const auto &node : *node_list) { + MS_EXCEPTION_IF_NULL(node); + if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) { + all_opt_list.emplace_back(node); + } else { + non_opt_list.emplace_back(node); + } + } + node_list->clear(); + std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list)); + std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); +} + } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 223917ceec..cd14a8b20d 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -189,6 +189,7 @@ class AnfRuntimeAlgorithm { static bool IsSwitchCall(const CNodePtr &call_node); static bool IsScalarInput(const CNodePtr &cnode, size_t index); static bool IsScalarOutput(const CNodePtr &cnode, size_t index); + static void ReorderExecList(NotNull *> node_list); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 02217291f3..1bdc3876c1 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -40,7 +40,7 @@ static void InitUnionFindSet(NotNull kg, const NotNullinsert(kg.get()); - const std::map> &real_inputs = kg->real_inputs(); + const std::vector>> &real_inputs = kg->real_inputs(); for (auto &iter : real_inputs) { auto ¶ = iter.first; if (para->isa()) { @@ -65,7 +65,7 @@ static void UnionParentParameter(NotNull kg, const NotNullinsert(kg.get()); - const std::map> &real_inputs = kg->real_inputs(); + const std::vector>> &real_inputs = kg->real_inputs(); for (auto &iter : real_inputs) { auto ¶ = iter.first; for (auto &arg : iter.second) { @@ -174,16 +174,18 @@ void AscendControlParser::ChildGraphDataAssign(const std::map> &real_inputs = kg->real_inputs(); - for (auto &in : kg->inputs()) { - auto it = real_inputs.find(in); - if (it == real_inputs.end()) { - continue; - } - auto ¶meter = it->first; - auto &args = it->second; + std::set> memo; + const std::vector>> &real_inputs = kg->real_inputs(); + for (auto &it : real_inputs) { + auto ¶meter = it.first; + auto &args = it.second; for (auto &arg : args) { MS_EXCEPTION_IF_NULL(arg); + if (memo.find({parameter, arg}) != memo.end()) { + continue; + } else { + memo.emplace(parameter, arg); + } if (arg->isa()) { MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() << ", arg:" << arg->DebugString(); @@ -193,7 +195,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::mapsecond), NOT_NULL(arg), NOT_NULL(parameter)); + InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); } } } @@ -285,17 +287,8 @@ void AscendControlParser::InsertControlDependToGraph(NotNull kg, void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, const CNodePtr &last_label) { - auto origin_return = kg->get_return(); - const std::vector &origin_return_inputs = origin_return->inputs(); - // if entry graph, replace return with make_tuple - if (from_graph_call_node == nullptr || last_label == nullptr) { - MS_LOG(INFO) << kg->ToString() << " is entry graph."; - std::vector make_tuple_inputs = {std::make_shared(prim::kPrimMakeTuple)}; - make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end()); - auto make_tuple = kg->NewCNode(make_tuple_inputs); - origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple}); - } else { - // else replace return with label_goto + // if not entry graph, replace return with label_goto + if (from_graph_call_node != nullptr && last_label != nullptr) { auto label_goto = kg->NewCNode({std::make_shared(std::make_shared(kLabelGotoOpName)), last_label}); MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString(); @@ -428,6 +421,20 @@ std::tuple AscendControlParser::ParsePartial(NotNull kg, NotNull from, + NotNull to) { + std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); + std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); + MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]"; + if (from_outputs.size() != to_outputs.size()) { + MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" + << to_outputs.size() << "]"; + } + for (size_t i = 0; i < from_outputs.size(); i++) { + InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + } +} + void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, NotNull to) { if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && @@ -457,7 +464,16 @@ std::vector AscendControlParser::RecurseGraph(NotNull } memo->insert(graph.get()); graph->SetExecOrderByDefault(); - const std::vector &cnodes = graph->execution_order(); + std::vector cnodes = graph->execution_order(); + + auto end_label_goto = graph->get_end_goto(); + if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) { + cnodes.pop_back(); + } + AnfAlgo::ReorderExecList(NOT_NULL(&cnodes)); + if (end_label_goto != nullptr) { + cnodes.push_back(end_label_goto); + } std::vector execution_order; uint32_t child_order_index = 0; diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index 05f5e19729..73d68449b3 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -52,6 +52,7 @@ class AscendControlParser { const CNodePtr &last_label); static std::tuple ParsePartial(NotNull node); + static void InsertMultipleAssignToGraph(NotNull kg, NotNull from, NotNull to); static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); // root graph order diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index ce421d03c7..fd3fc5bf64 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -545,7 +545,6 @@ void AscendSession::HardwareOptimize(const std::shared_ptr &kernel_ void AscendSession::AdjustKernel(const std::shared_ptr &kernel_graph) const { MS_LOG(INFO) << "Start!"; - device::KernelAdjust::GetInstance().Reorder(kernel_graph); opt::HideNopNode(kernel_graph.get()); // Insert CLearZero op // prepare for next step from json get atomic info @@ -578,7 +577,7 @@ void AscendSession::RunOpAdjustKernel(const std::shared_ptr &kernel void AscendSession::AssignStream(const std::shared_ptr &kernel_graph) const { MS_LOG(INFO) << "Start!"; - device::ascend::AscendStreamAssign::GetInstance().AssignStreamNew(kernel_graph); + device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); MS_LOG(INFO) << "Finish!"; } @@ -1512,6 +1511,11 @@ void AscendSession::SplitGraphs(NotNull root_graph) { RecurseSplitGraph(root_graph, NOT_NULL(&memo)); } memo.clear(); + // add maketuple to the end of the last child graph to suit old process + auto output_graph = root_graph->child_graph_order().empty() ? root_graph : root_graph->child_graph_order().back(); + auto make_tuple = output_graph->NewCNode( + {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), output_graph->output()}); + output_graph->set_output(make_tuple); // replace the real input if the real input is a call RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo)); } diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 9f3b5bbac4..6bc0ec8677 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -43,12 +43,28 @@ void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { auto item_with_index = AnfAlgo::VisitKernelWithReturnType(call_node, 0); - MS_EXCEPTION_IF_NULL(item_with_index.first); - if (!AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) { - return {item_with_index.first}; + AnfNodePtr node = item_with_index.first; + MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { + auto outputs = AnfAlgo::GetAllOutput(node); + std::set memo; + std::vector new_output; + for (auto &output : outputs) { + if (memo.find(output) != memo.end()) { + continue; + } + memo.insert(output); + new_output.push_back(output); + } + if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) { + node = new_output[0]; + } + } + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { + return {node}; } std::vector real_inputs; - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast()); + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast()); for (const auto &child_graph : child_graphs) { if (child_graph->get_output_null()) { continue; @@ -623,18 +639,25 @@ void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull> &n) -> bool { + return n.first == old_anf_node.get(); + }); if (it_real_inputs != real_inputs_.end()) { + // erase old parameter in map + auto old_args = it_real_inputs->second; + real_inputs_.erase(it_real_inputs); // insert new parameter to map - auto iter = real_inputs_.find(new_anf_node); + auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(), + [&new_anf_node](const std::pair> &n) -> bool { + return n.first == new_anf_node.get(); + }); if (iter != real_inputs_.end()) { MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited."; - iter->second = it_real_inputs->second; + iter->second = old_args; } else { - real_inputs_[new_anf_node.get()] = it_real_inputs->second; + real_inputs_.emplace_back(new_anf_node, old_args); } - // erase old parameter in map - real_inputs_.erase(old_anf_node); } } @@ -676,57 +699,33 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString(); MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(arg); - if (real_inputs_.find(parameter) == real_inputs_.end()) { - real_inputs_[parameter] = std::vector(); - } - auto &args = real_inputs_[parameter]; - (void)args.push_back(arg); -} - -std::vector KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { - MS_EXCEPTION_IF_NULL(parameter); - auto iter = real_inputs_.find(parameter); + auto iter = std::find_if( + real_inputs_.begin(), real_inputs_.end(), + [¶meter](const std::pair> &n) -> bool { return n.first == parameter; }); if (iter != real_inputs_.end()) { - return iter->second; + auto &args = iter->second; + args.push_back(arg); + } else { + real_inputs_.emplace_back(parameter, std::vector(1, arg)); } - MS_LOG(EXCEPTION) << parameter->DebugString() << " not found."; } void KernelGraph::UpdateCallRealInput() { MS_LOG(INFO) << "Update graph id: " << graph_id_; - std::map> real_inputs_map; + std::vector>> real_inputs_map; for (auto &it : real_inputs_) { auto parameter = it.first; MS_EXCEPTION_IF_NULL(parameter); auto real_inputs = it.second; std::vector new_real_inputs; - std::set erase_real_inputs; for (auto &real_input : real_inputs) { // if real input is a call node ,find the child graph output act as the new real input auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0); MS_EXCEPTION_IF_NULL(item_with_index.first); - if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) { - (void)erase_real_inputs.insert(item_with_index.first); - new_real_inputs = GetCallRealOutputs(item_with_index.first); - continue; - } - } - for (auto &erase_node : erase_real_inputs) { - MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString(); - for (auto iter = real_inputs.begin(); iter != real_inputs.end();) { - if (*iter == erase_node) { - iter = real_inputs.erase(iter); - } else { - ++iter; - } - } - } - for (auto &new_real_input : new_real_inputs) { - MS_LOG(INFO) << "paramter: " << parameter->DebugString() - << " insert real input:" << new_real_input->DebugString(); - (void)real_inputs.push_back(new_real_input); + auto tmp_real_input = GetCallRealOutputs(item_with_index.first); + std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs)); } - real_inputs_map[parameter] = real_inputs; + real_inputs_map.emplace_back(parameter, new_real_inputs); } real_inputs_ = real_inputs_map; } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index dbb79d561c..9954b5b1d0 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -127,8 +127,7 @@ class KernelGraph : public FuncGraph { // find anf node in graph std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; // get real inputs - const std::map> &real_inputs() const { return real_inputs_; } - std::vector GetRealInput(const AnfNodePtr ¶meter); + const std::vector>> &real_inputs() const { return real_inputs_; } void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); // used to dump ir std::string ToString() const override; @@ -197,7 +196,7 @@ class KernelGraph : public FuncGraph { // parameter graph std::shared_ptr parent_graph_; // record real parameters,inputs_ is the formal parameters - std::map> real_inputs_; + std::vector>> real_inputs_; CNodePtr start_label_; CNodePtr end_goto_; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index ae2e085800..0ff75878aa 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -727,23 +727,7 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { summary_callback_ = callback; } -void SessionBasic::Reorder(std::vector *node_list) { - MS_EXCEPTION_IF_NULL(node_list); - std::vector all_opt_list; - std::vector non_opt_list; - - for (const auto &node : *node_list) { - MS_EXCEPTION_IF_NULL(node); - if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) { - all_opt_list.emplace_back(node); - } else { - non_opt_list.emplace_back(node); - } - } - node_list->clear(); - (void)std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list)); - (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); -} +void SessionBasic::Reorder(std::vector *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } void SessionBasic::GetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary Start"; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index e33a68e1c5..91aa974cdf 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -857,6 +857,7 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { } bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); auto graph_manager = graph->manager(); MS_EXCEPTION_IF_NULL(graph_manager); FuncGraphSet graphs = graph_manager->func_graphs(); diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 3a1da0ff42..069a97a234 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -124,7 +124,7 @@ class CompileGraphs { void Compile(const FuncGraphPtr &func_graph); FinalVMPtr Link(const FuncGraphPtr &func_graph); FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); - bool ContainMixedTarget(const FuncGraphPtr &graph); + static bool ContainMixedTarget(const FuncGraphPtr &graph); private: InstSet insts_; diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index 5d8e33b256..fba52323cf 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -26,12 +26,12 @@ void AscendLabelAssign::AssignLabel(NotNull graph) { return 1; } uint32_t AscendLabelAssign::GetLabelNum(NotNull> graph) { return 1; } -void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } - -uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } +void AscendStreamAssign::AssignStream(const KernelGraphPtr &graph) { return; } void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { return; } +void AscendStreamAssign::GetHcomStreams(std::vector *streams) { return; } + namespace tasksink { bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *const task_info_list, uint32_t graph_id) { @@ -39,7 +39,6 @@ bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::ve } } // namespace tasksink } // namespace ascend -void KernelAdjust::Reorder(const std::shared_ptr &kernel_graph_ptr) { return; } void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { return; } bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr) { return true; } bool KernelAdjust::NeedInsertSwitch() { return true; }