| @@ -283,18 +283,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||||
| AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); | AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); | ||||
| // the streams' flag not HEAD_STREAM | // the streams' flag not HEAD_STREAM | ||||
| std::vector<uint32_t> wait_active_stream_list = assign_instance.GetWaitStreams(); | |||||
| std::vector<uint32_t> force_copy_stream_list = assign_instance.GetHcomStreams(); | |||||
| std::vector<uint32_t> wait_active_stream_list; | |||||
| assign_instance.GetWaitStreams(&wait_active_stream_list); | |||||
| auto force_copy_stream_list = assign_instance.hcom_streams(); | |||||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() | MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() | ||||
| << ", total event num:" << assign_instance.GetTotalEventNum() | |||||
| << ", total event num:" << assign_instance.total_event_num() | |||||
| << ", wait_active_stream_list size:" << wait_active_stream_list.size() | << ", wait_active_stream_list size:" << wait_active_stream_list.size() | ||||
| << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | ||||
| std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | ||||
| std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( | std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( | ||||
| task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | ||||
| 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.GetTotalEventNum(), 0); | |||||
| 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0); | |||||
| auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | ||||
| if (!ret.second) { | if (!ret.second) { | ||||
| @@ -25,8 +25,8 @@ | |||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "device/kernel_adjust.h" | #include "device/kernel_adjust.h" | ||||
| #include "predict/generator/utils/ir_model_util.h" | #include "predict/generator/utils/ir_model_util.h" | ||||
| #include "device/kernel_info.h" | |||||
| #include "pre_activate/common/helper.h" | #include "pre_activate/common/helper.h" | ||||
| #include "utils/utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| @@ -54,6 +54,7 @@ void AscendStreamAssign::ResetNew() { | |||||
| inner_parallel_streams_.clear(); | inner_parallel_streams_.clear(); | ||||
| processed_parallel_streams_.clear(); | processed_parallel_streams_.clear(); | ||||
| hcom_stream_list_.clear(); | hcom_stream_list_.clear(); | ||||
| need_first_active_streams_.clear(); | |||||
| } | } | ||||
| void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) { | void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) { | ||||
| @@ -200,13 +201,12 @@ void AscendStreamAssign::AssignAllNodesStream(const shared_ptr<session::KernelGr | |||||
| MS_LOG(INFO) << "stream nums:common:" << total_common_stream_num_ << ",independ:" << total_independ_stream_num_; | MS_LOG(INFO) << "stream nums:common:" << total_common_stream_num_ << ",independ:" << total_independ_stream_num_; | ||||
| } | } | ||||
| vector<uint32_t> AscendStreamAssign::TransLogicToPhysic(const vector<uint32_t> &logic_ids) { | |||||
| vector<uint32_t> physic_ids; | |||||
| void AscendStreamAssign::TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids) { | |||||
| for (auto &id : logic_ids) { | for (auto &id : logic_ids) { | ||||
| auto it = logic_to_physic_map_.find(id); | auto it = logic_to_physic_map_.find(id); | ||||
| if (it != logic_to_physic_map_.end()) { | if (it != logic_to_physic_map_.end()) { | ||||
| MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]"; | MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]"; | ||||
| physic_ids.push_back(it->second); | |||||
| (*physic_ids).push_back(it->second); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id"; | MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id"; | ||||
| } | } | ||||
| @@ -214,10 +214,9 @@ vector<uint32_t> AscendStreamAssign::TransLogicToPhysic(const vector<uint32_t> & | |||||
| auto it_independ = logic_to_independent_map_.find(id); | auto it_independ = logic_to_independent_map_.find(id); | ||||
| if (it_independ != logic_to_independent_map_.end()) { | if (it_independ != logic_to_independent_map_.end()) { | ||||
| MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]"; | MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]"; | ||||
| physic_ids.push_back(it_independ->second); | |||||
| (*physic_ids).push_back(it_independ->second); | |||||
| } | } | ||||
| } | } | ||||
| return physic_ids; | |||||
| } | } | ||||
| void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { | void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { | ||||
| @@ -227,7 +226,8 @@ void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| vector<uint32_t> active_logic_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList)); | vector<uint32_t> active_logic_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList)); | ||||
| // out StreamAcitve active physic stream is not parallel now, if parallel, should deal here. | // out StreamAcitve active physic stream is not parallel now, if parallel, should deal here. | ||||
| vector<uint32_t> active_physic_ids = TransLogicToPhysic(active_logic_ids); | |||||
| vector<uint32_t> active_physic_ids; | |||||
| TransLogicToPhysic(active_logic_ids, &active_physic_ids); | |||||
| ValuePtr active_physic_value = MakeValue<std::vector<uint32_t>>(active_physic_ids); | ValuePtr active_physic_value = MakeValue<std::vector<uint32_t>>(active_physic_ids); | ||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr); | AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr); | ||||
| } | } | ||||
| @@ -242,7 +242,8 @@ void AscendStreamAssign::UpdateStreamSwitch(const CNodePtr &switch_ptr, const CN | |||||
| MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id | MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id | ||||
| << "]"; | << "]"; | ||||
| vector<uint32_t> logic_ids{true_logic_id}; | vector<uint32_t> logic_ids{true_logic_id}; | ||||
| vector<uint32_t> physic_ids = TransLogicToPhysic(logic_ids); | |||||
| vector<uint32_t> physic_ids; | |||||
| TransLogicToPhysic(logic_ids, &physic_ids); | |||||
| if (physic_ids.empty()) { | if (physic_ids.empty()) { | ||||
| MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id"; | MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id"; | ||||
| } | } | ||||
| @@ -334,8 +335,8 @@ bool AscendStreamAssign::IsProcessedParallelStream(uint32_t stream_id) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| vector<uint32_t> AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id) { | |||||
| vector<uint32_t> parallel_streams; | |||||
| void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, | |||||
| vector<uint32_t> *parallel_streams) { | |||||
| for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { | for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { | ||||
| auto cur_parallel_streams = inner_parallel_streams_[i]; | auto cur_parallel_streams = inner_parallel_streams_[i]; | ||||
| auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); | auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); | ||||
| @@ -347,17 +348,17 @@ vector<uint32_t> AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, u | |||||
| << "is same with streamacvite stream id" << stream_acitve_id; | << "is same with streamacvite stream id" << stream_acitve_id; | ||||
| continue; | continue; | ||||
| } | } | ||||
| parallel_streams.emplace_back(cur_parallel_streams[j]); | |||||
| (*parallel_streams).emplace_back(cur_parallel_streams[j]); | |||||
| } | } | ||||
| // record processed parallel streams | // record processed parallel streams | ||||
| (void)std::copy(parallel_streams.begin(), parallel_streams.end(), | |||||
| (void)std::copy((*parallel_streams).begin(), (*parallel_streams).end(), | |||||
| std::back_inserter(processed_parallel_streams_)); | std::back_inserter(processed_parallel_streams_)); | ||||
| return parallel_streams; | |||||
| return; | |||||
| } | } | ||||
| } | } | ||||
| return vector<uint32_t>{cur_stream_id}; | |||||
| (*parallel_streams).push_back(cur_stream_id); | |||||
| } | } | ||||
| void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr) { | void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr) { | ||||
| @@ -379,30 +380,32 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGr | |||||
| } | } | ||||
| bool inner_active = pre_stream_id != cur_stream_id && pre_stream_id < cur_stream_id && | bool inner_active = pre_stream_id != cur_stream_id && pre_stream_id < cur_stream_id && | ||||
| AnfAlgo::GetCNodeName(pre_cnode_ptr) != "StreamSwitch" && | |||||
| AnfAlgo::GetCNodeName(pre_cnode_ptr) != "StreamActive"; | |||||
| AnfAlgo::GetCNodeName(pre_cnode_ptr) != kStreamSwitchOpName && | |||||
| AnfAlgo::GetCNodeName(pre_cnode_ptr) != kStreamActiveOpName && | |||||
| AnfAlgo::GetCNodeName(pre_cnode_ptr) != kSendOpName; | |||||
| bool processed = IsProcessedParallelStream(cur_stream_id); | bool processed = IsProcessedParallelStream(cur_stream_id); | ||||
| // 1)inner stream assign, need insert active op | // 1)inner stream assign, need insert active op | ||||
| if (inner_active && !processed) { | if (inner_active && !processed) { | ||||
| MS_LOG(INFO) << "Inner insert active op, self stream id[" << pre_stream_id << "]"; | MS_LOG(INFO) << "Inner insert active op, self stream id[" << pre_stream_id << "]"; | ||||
| CNodePtr active_ptr = KernelAdjust::GetInstance().CreateSteamActiveOp(graph_ptr); | |||||
| CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); | |||||
| update_cnode_list.emplace_back(active_ptr); | update_cnode_list.emplace_back(active_ptr); | ||||
| update_cnode_list.emplace_back(cur_cnode_ptr); | |||||
| // 1.set stream id | // 1.set stream id | ||||
| AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); | AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); | ||||
| // 2.set active stream ids | // 2.set active stream ids | ||||
| vector<uint32_t> active_index_list = GetParallelStream(cur_stream_id, pre_stream_id); | |||||
| std::vector<uint32_t> active_index_list; | |||||
| GetParallelStream(cur_stream_id, pre_stream_id, &active_index_list); | |||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr); | AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr); | ||||
| } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive" && | |||||
| AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { | |||||
| } | |||||
| // inner_active handling is independent of the if/else chain below (both can apply), e.g. StreamActive(S7)-->StreamActive(S8) | |||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName && | |||||
| AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { | |||||
| // 2)outter stream assign, update active op | // 2)outter stream assign, update active op | ||||
| update_cnode_list.emplace_back(cur_cnode_ptr); | update_cnode_list.emplace_back(cur_cnode_ptr); | ||||
| UpdateStreamActive(cur_cnode_ptr); | UpdateStreamActive(cur_cnode_ptr); | ||||
| } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamSwitch") { | |||||
| } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { | |||||
| // 3)update switch op | // 3)update switch op | ||||
| MS_LOG(INFO) << "Insert active op after switch"; | MS_LOG(INFO) << "Insert active op after switch"; | ||||
| CNodePtr active_ptr = KernelAdjust::GetInstance().CreateSteamActiveOp(graph_ptr); | |||||
| CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); | |||||
| update_cnode_list.emplace_back(cur_cnode_ptr); | update_cnode_list.emplace_back(cur_cnode_ptr); | ||||
| update_cnode_list.emplace_back(active_ptr); | update_cnode_list.emplace_back(active_ptr); | ||||
| UpdateStreamSwitch(cur_cnode_ptr, active_ptr); | UpdateStreamSwitch(cur_cnode_ptr, active_ptr); | ||||
| @@ -417,6 +420,37 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGr | |||||
| MS_LOG(INFO) << "end"; | MS_LOG(INFO) << "end"; | ||||
| } | } | ||||
| void AscendStreamAssign::UpdateEventId(const shared_ptr<session::KernelGraph> &graph_ptr) { | |||||
| MS_LOG(INFO) << "start"; | |||||
| MS_EXCEPTION_IF_NULL(graph_ptr); | |||||
| CNodePtr cur_cnode_ptr = nullptr; | |||||
| // key:virutal event id, value:real event id | |||||
| std::unordered_map<uint32_t, uint32_t> event_id_map; | |||||
| uint32_t event_id; | |||||
| auto cnode_ptr_list = graph_ptr->execution_order(); | |||||
| for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | |||||
| cur_cnode_ptr = cnode_ptr_list[i]; | |||||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | |||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) { | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| event_id = GetValue<uint32_t>(primitive->GetAttr(kAttrEventId)); | |||||
| // before stream assign, send/recv event_id assign from kFirstEventId | |||||
| if (event_id < kFirstEventId) { | |||||
| continue; | |||||
| } | |||||
| auto it = event_id_map.find(event_id); | |||||
| if (it == event_id_map.end()) { | |||||
| event_id_map.insert(std::make_pair(event_id, total_event_num_)); | |||||
| AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue<uint32_t>(total_event_num_), cur_cnode_ptr); | |||||
| total_event_num_++; | |||||
| } else { | |||||
| AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue<uint32_t>(it->second), cur_cnode_ptr); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> &graph_ptr) { | void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> &graph_ptr) { | ||||
| MS_LOG(INFO) << "start"; | MS_LOG(INFO) << "start"; | ||||
| MS_EXCEPTION_IF_NULL(graph_ptr); | MS_EXCEPTION_IF_NULL(graph_ptr); | ||||
| @@ -427,7 +461,7 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> & | |||||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | ||||
| uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); | uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); | ||||
| if (cur_stream_id < kIndependFirstStreamId) { | if (cur_stream_id < kIndependFirstStreamId) { | ||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive") { | |||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName) { | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | ||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| vector<uint32_t> active_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList)); | vector<uint32_t> active_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList)); | ||||
| @@ -471,6 +505,29 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> & | |||||
| MS_LOG(INFO) << "end"; | MS_LOG(INFO) << "end"; | ||||
| } | } | ||||
| void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGraph> &graph_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(graph_ptr); | |||||
| CNodePtr cur_cnode_ptr = nullptr; | |||||
| auto cnode_ptr_list = graph_ptr->execution_order(); | |||||
| for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | |||||
| cur_cnode_ptr = cnode_ptr_list[i]; | |||||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | |||||
| MS_EXCEPTION_IF_NULL(primitive); | |||||
| auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); | |||||
| if (value_ptr == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto need_active = GetValue<bool>(value_ptr); | |||||
| if (need_active) { | |||||
| auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); | |||||
| MS_LOG(INFO) << "stream id:" << stream_id << " is need actived at first"; | |||||
| need_first_active_streams_.push_back(stream_id); | |||||
| } | |||||
| } | |||||
| } | |||||
| void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) { | void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) { | ||||
| if (IsTaskSink()) { | if (IsTaskSink()) { | ||||
| ResetNew(); | ResetNew(); | ||||
| @@ -480,13 +537,15 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> | |||||
| InsertSendRecvForHcomParallel(graph_ptr); | InsertSendRecvForHcomParallel(graph_ptr); | ||||
| InsertSendRecvForIndependent(graph_ptr); | InsertSendRecvForIndependent(graph_ptr); | ||||
| UpdateStreamId(graph_ptr); | UpdateStreamId(graph_ptr); | ||||
| UpdateEventId(graph_ptr); | |||||
| GetNeedActiveStreams(graph_ptr); | |||||
| MS_LOG(INFO) << "after finish stream assign"; | MS_LOG(INFO) << "after finish stream assign"; | ||||
| PrintGraphExeOrders(graph_ptr); | PrintGraphExeOrders(graph_ptr); | ||||
| // Get info for D Model | // Get info for D Model | ||||
| generator::IRModelUtil::GetInstance().set_event_num(GetTotalEventNum()); | |||||
| generator::IRModelUtil::GetInstance().set_stream_num(GetTotalCommonStreamNum() + GetTotalIndependStreamNum()); | |||||
| generator::IRModelUtil::GetInstance().set_event_num(total_event_num()); | |||||
| generator::IRModelUtil::GetInstance().set_stream_num(total_common_stream_num() + total_independ_stream_num()); | |||||
| // Init to 1,temporarily | // Init to 1,temporarily | ||||
| generator::IRModelUtil::GetInstance().set_batch_num(1); | generator::IRModelUtil::GetInstance().set_batch_num(1); | ||||
| } | } | ||||
| @@ -495,7 +554,7 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> | |||||
| CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, | CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, | ||||
| uint32_t event_id, uint32_t stream_id) { | uint32_t event_id, uint32_t stream_id) { | ||||
| MS_EXCEPTION_IF_NULL(graph_ptr); | MS_EXCEPTION_IF_NULL(graph_ptr); | ||||
| auto send_op = std::make_shared<Primitive>("Send"); | |||||
| auto send_op = std::make_shared<Primitive>(kSendOpName); | |||||
| MS_EXCEPTION_IF_NULL(send_op); | MS_EXCEPTION_IF_NULL(send_op); | ||||
| auto send_apply = std::make_shared<ValueNode>(send_op); | auto send_apply = std::make_shared<ValueNode>(send_op); | ||||
| MS_EXCEPTION_IF_NULL(send_apply); | MS_EXCEPTION_IF_NULL(send_apply); | ||||
| @@ -505,7 +564,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; | kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; | ||||
| selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); | selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); | AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); | ||||
| AnfAlgo::SetNodeAttr("event_id", MakeValue(event_id), send_node_ptr); | |||||
| AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); | |||||
| auto abstract_none = std::make_shared<abstract::AbstractNone>(); | auto abstract_none = std::make_shared<abstract::AbstractNone>(); | ||||
| MS_EXCEPTION_IF_NULL(abstract_none); | MS_EXCEPTION_IF_NULL(abstract_none); | ||||
| send_node_ptr->set_abstract(abstract_none); | send_node_ptr->set_abstract(abstract_none); | ||||
| @@ -516,7 +575,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session | |||||
| CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, | CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, | ||||
| uint32_t event_id, uint32_t stream_id) { | uint32_t event_id, uint32_t stream_id) { | ||||
| MS_EXCEPTION_IF_NULL(graph_ptr); | MS_EXCEPTION_IF_NULL(graph_ptr); | ||||
| auto recv_op = std::make_shared<Primitive>("Recv"); | |||||
| auto recv_op = std::make_shared<Primitive>(kRecvOpName); | |||||
| MS_EXCEPTION_IF_NULL(recv_op); | MS_EXCEPTION_IF_NULL(recv_op); | ||||
| auto recv_apply = std::make_shared<ValueNode>(recv_op); | auto recv_apply = std::make_shared<ValueNode>(recv_op); | ||||
| MS_EXCEPTION_IF_NULL(recv_apply); | MS_EXCEPTION_IF_NULL(recv_apply); | ||||
| @@ -526,7 +585,7 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr<session | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; | kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; | ||||
| selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); | selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); | AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); | ||||
| AnfAlgo::SetNodeAttr("event_id", MakeValue(event_id), recv_node_ptr); | |||||
| AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); | |||||
| AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get()); | AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get()); | ||||
| auto abstract_none = std::make_shared<abstract::AbstractNone>(); | auto abstract_none = std::make_shared<abstract::AbstractNone>(); | ||||
| MS_EXCEPTION_IF_NULL(abstract_none); | MS_EXCEPTION_IF_NULL(abstract_none); | ||||
| @@ -605,7 +664,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (AnfAlgo::GetCNodeName(node_ptr) == "GetNext") { | |||||
| if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { | |||||
| MS_LOG(INFO) << "GetNext should not be independent node"; | MS_LOG(INFO) << "GetNext should not be independent node"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -638,20 +697,23 @@ bool AscendStreamAssign::IsTaskSink() { | |||||
| } | } | ||||
| } | } | ||||
| std::vector<uint32_t> AscendStreamAssign::GetWaitStreams() { | |||||
| vector<uint32_t> wait_active_stream_list; | |||||
| void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) { | |||||
| if (total_common_stream_num_ == 0) { | if (total_common_stream_num_ == 0) { | ||||
| MS_LOG(INFO) << "total_common_stream_num is zero"; | MS_LOG(INFO) << "total_common_stream_num is zero"; | ||||
| return wait_active_stream_list; | |||||
| return; | |||||
| } | } | ||||
| // common stream:active first common stream | // common stream:active first common stream | ||||
| MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]"; | MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]"; | ||||
| for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) { | for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) { | ||||
| MS_LOG(INFO) << "wait common stream id = " << i; | |||||
| wait_active_stream_list.push_back(i); | |||||
| auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); | |||||
| if (it == need_first_active_streams_.end()) { | |||||
| MS_LOG(INFO) << "wait common stream id = " << i; | |||||
| (*wait_active_stream_list).push_back(i); | |||||
| } | |||||
| } | } | ||||
| // all independent stream ids before the first physical stream id should be activated | |||||
| auto it = logic_to_independent_map_.find(first_logic_id_); | auto it = logic_to_independent_map_.find(first_logic_id_); | ||||
| if (it != logic_to_independent_map_.end()) { | if (it != logic_to_independent_map_.end()) { | ||||
| uint32_t independent_id = it->second; | uint32_t independent_id = it->second; | ||||
| @@ -675,16 +737,14 @@ std::vector<uint32_t> AscendStreamAssign::GetWaitStreams() { | |||||
| if (i + total_common_stream_num_ <= max_before_physic) { | if (i + total_common_stream_num_ <= max_before_physic) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; | |||||
| wait_active_stream_list.push_back(i + total_common_stream_num_); | |||||
| // all wait streams should not in need_first_active_streams_ | |||||
| auto iter = | |||||
| std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i + total_common_stream_num_); | |||||
| if (iter == need_first_active_streams_.end()) { | |||||
| MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; | |||||
| (*wait_active_stream_list).push_back(i + total_common_stream_num_); | |||||
| } | |||||
| } | } | ||||
| return wait_active_stream_list; | |||||
| } | |||||
| std::vector<uint32_t> AscendStreamAssign::GetHcomStreams() { | |||||
| MS_LOG(INFO) << "hcom total stream nums:" << hcom_stream_list_.size(); | |||||
| return hcom_stream_list_; | |||||
| } | } | ||||
| uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; } | uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; } | ||||
| @@ -695,7 +755,7 @@ void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<mindspore::session | |||||
| for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { | ||||
| CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; | CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; | ||||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | ||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "Send" || AnfAlgo::GetCNodeName(cur_cnode_ptr) == "Recv") { | |||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) { | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); | ||||
| MS_LOG(INFO) << "node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id[" | MS_LOG(INFO) << "node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id[" | ||||
| << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" | << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" | ||||
| @@ -49,37 +49,35 @@ class AscendStreamAssign { | |||||
| uint32_t GetTotalStreamNum() const; | uint32_t GetTotalStreamNum() const; | ||||
| // new stream policy | // new stream policy | ||||
| uint32_t GetTotalCommonStreamNum() const { return total_common_stream_num_; } | |||||
| uint32_t GetTotalIndependStreamNum() const { return total_independ_stream_num_; } | |||||
| uint32_t GetTotalEventNum() const { return total_event_num_; } | |||||
| const uint32_t GetFisrtPhysicId() const { return first_physic_id_; } | |||||
| const uint32_t GetFirstLogicId() const { return first_logic_id_; } | |||||
| uint32_t total_common_stream_num() const { return total_common_stream_num_; } | |||||
| uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } | |||||
| uint32_t total_event_num() const { return total_event_num_; } | |||||
| void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); | void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); | ||||
| void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr); | void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr); | ||||
| void ResetNew(); | void ResetNew(); | ||||
| void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); | void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); | ||||
| bool IsIndependentNode(const CNodePtr& node_ptr); | bool IsIndependentNode(const CNodePtr& node_ptr); | ||||
| const std::unordered_map<uint32_t, uint32_t> GetIndependentMap() { return logic_to_independent_map_; } | |||||
| const std::unordered_map<uint32_t, uint32_t> GetPhysicMap() { return logic_to_physic_map_; } | |||||
| std::vector<uint32_t> GetWaitStreams(); | |||||
| std::vector<uint32_t> GetHcomStreams(); | |||||
| private: | |||||
| AscendStreamAssign() = default; | |||||
| ~AscendStreamAssign() = default; | |||||
| const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; } | |||||
| const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; } | |||||
| const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; } | |||||
| void GetWaitStreams(vector<uint32_t>* wait_active_stream_list); | |||||
| const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; } | |||||
| CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, | CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, | ||||
| uint32_t stream_id); | uint32_t stream_id); | ||||
| CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, | CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, | ||||
| uint32_t stream_id); | uint32_t stream_id); | ||||
| private: | |||||
| AscendStreamAssign() = default; | |||||
| ~AscendStreamAssign() = default; | |||||
| vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, | vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, | ||||
| const CNodePtr& node); | const CNodePtr& node); | ||||
| bool IsHcom(const CNodePtr& apply_kernel); | bool IsHcom(const CNodePtr& apply_kernel); | ||||
| bool IsProcessed(uint32_t logic_id); | bool IsProcessed(uint32_t logic_id); | ||||
| vector<uint32_t> TransLogicToPhysic(const vector<uint32_t>& logic_ids); | |||||
| void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids); | |||||
| void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, | void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, | ||||
| uint32_t* cur_stream_id); | uint32_t* cur_stream_id); | ||||
| void RecordIdMap(uint32_t logic_id, uint32_t physic_id); | void RecordIdMap(uint32_t logic_id, uint32_t physic_id); | ||||
| @@ -88,15 +86,17 @@ class AscendStreamAssign { | |||||
| bool IsTaskSink(); | bool IsTaskSink(); | ||||
| void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); | void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); | ||||
| void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr); | void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr); | ||||
| void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr); | void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr); | ||||
| void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); | void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); | ||||
| uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); | uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); | ||||
| void SetCommonStreamNum(uint32_t cur_stream_id); | void SetCommonStreamNum(uint32_t cur_stream_id); | ||||
| void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); | void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); | ||||
| bool IsProcessedParallelStream(uint32_t stream_id); | bool IsProcessedParallelStream(uint32_t stream_id); | ||||
| vector<uint32_t> GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id); | |||||
| void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams); | |||||
| void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr); | void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr); | ||||
| void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); | void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); | ||||
| void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr); | |||||
| uint32_t total_common_stream_num_{0}; | uint32_t total_common_stream_num_{0}; | ||||
| uint32_t total_independ_stream_num_{0}; | uint32_t total_independ_stream_num_{0}; | ||||
| @@ -112,6 +112,7 @@ class AscendStreamAssign { | |||||
| std::vector<std::vector<uint32_t>> inner_parallel_streams_{}; | std::vector<std::vector<uint32_t>> inner_parallel_streams_{}; | ||||
| std::vector<uint32_t> processed_parallel_streams_{}; | std::vector<uint32_t> processed_parallel_streams_{}; | ||||
| std::vector<uint32_t> hcom_stream_list_{}; | std::vector<uint32_t> hcom_stream_list_{}; | ||||
| std::vector<uint32_t> need_first_active_streams_{}; | |||||
| // new policy end | // new policy end | ||||
| }; | }; | ||||
| } // namespace ascend | } // namespace ascend | ||||
| @@ -32,16 +32,8 @@ | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "device/ascend/profiling/profiling_manager.h" | #include "device/ascend/profiling/profiling_manager.h" | ||||
| #include "device/ascend/kernel_select_ascend.h" | #include "device/ascend/kernel_select_ascend.h" | ||||
| #include "device/kernel_info.h" | |||||
| #include "runtime/base.h" | #include "runtime/base.h" | ||||
| constexpr auto kLoopCountParamName = "loop_count"; | |||||
| constexpr auto kIterLoopParamName = "iter_loop"; | |||||
| constexpr auto kZeroParamName = "zero"; | |||||
| constexpr auto kOneParamName = "one"; | |||||
| constexpr auto kStreamSwitch = "StreamSwitch"; | |||||
| constexpr auto kStreamActive = "StreamActive"; | |||||
| constexpr auto kAssignAdd = "AssignAdd"; | |||||
| #include "device/ascend/ascend_stream_assign.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| using device::ascend::ProfilingUtils; | using device::ascend::ProfilingUtils; | ||||
| @@ -70,6 +62,63 @@ bool KernelAdjust::NeedInsertSwitch() { | |||||
| ConfigManager::GetInstance().iter_num() > 1); | ConfigManager::GetInstance().iter_num() > 1); | ||||
| } | } | ||||
| uint32_t KernelAdjust::FindFirstStreamSwitchLabel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | |||||
| auto cnode_ptr_list = kernel_graph_ptr->execution_order(); | |||||
| CNodePtr cur_cnode_ptr = nullptr; | |||||
| uint32_t label = kInvalidDistincLabel; | |||||
| for (uint32_t i = 0; i < cnode_ptr_list.size(); ++i) { | |||||
| cur_cnode_ptr = cnode_ptr_list[i]; | |||||
| MS_EXCEPTION_IF_NULL(cur_cnode_ptr); | |||||
| if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { | |||||
| label = AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()); | |||||
| break; | |||||
| } | |||||
| } | |||||
| return label; | |||||
| } | |||||
| CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, | |||||
| uint32_t event_id) { | |||||
| MS_EXCEPTION_IF_NULL(graph_ptr); | |||||
| auto send_op = std::make_shared<Primitive>(kSendOpName); | |||||
| MS_EXCEPTION_IF_NULL(send_op); | |||||
| auto send_apply = std::make_shared<ValueNode>(send_op); | |||||
| MS_EXCEPTION_IF_NULL(send_apply); | |||||
| std::vector<AnfNodePtr> send_input_list = {send_apply}; | |||||
| CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); | |||||
| MS_EXCEPTION_IF_NULL(send_node_ptr); | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; | |||||
| selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); | |||||
| AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); | |||||
| auto abstract_none = std::make_shared<abstract::AbstractNone>(); | |||||
| MS_EXCEPTION_IF_NULL(abstract_none); | |||||
| send_node_ptr->set_abstract(abstract_none); | |||||
| return send_node_ptr; | |||||
| } | |||||
| CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, | |||||
| uint32_t event_id) { | |||||
| MS_EXCEPTION_IF_NULL(graph_ptr); | |||||
| auto recv_op = std::make_shared<Primitive>(kRecvOpName); | |||||
| MS_EXCEPTION_IF_NULL(recv_op); | |||||
| auto recv_apply = std::make_shared<ValueNode>(recv_op); | |||||
| MS_EXCEPTION_IF_NULL(recv_apply); | |||||
| std::vector<AnfNodePtr> recv_input_list = {recv_apply}; | |||||
| CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); | |||||
| MS_EXCEPTION_IF_NULL(recv_node_ptr); | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; | |||||
| selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); | |||||
| AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); | |||||
| auto abstract_none = std::make_shared<abstract::AbstractNone>(); | |||||
| MS_EXCEPTION_IF_NULL(abstract_none); | |||||
| recv_node_ptr->set_abstract(abstract_none); | |||||
| return recv_node_ptr; | |||||
| } | |||||
| void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | ||||
| if (!NeedInsertSwitch()) { | if (!NeedInsertSwitch()) { | ||||
| return; | return; | ||||
| @@ -93,21 +142,95 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| auto orders = kernel_graph_ptr->execution_order(); | |||||
| if (orders.empty()) { | |||||
| MS_LOG(EXCEPTION) << "graph execution order is empty"; | |||||
| } | |||||
| uint32_t first_cnode_stream_label = AnfAlgo::GetStreamDistinctionLabel(orders[0].get()); | |||||
| std::vector<CNodePtr> exec_order; | std::vector<CNodePtr> exec_order; | ||||
| CNodePtr stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | |||||
| MS_EXCEPTION_IF_NULL(stream_switch_app); | |||||
| exec_order.push_back(stream_switch_app); | |||||
| CNodePtr first_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | |||||
| MS_EXCEPTION_IF_NULL(first_stream_switch_app); | |||||
| AnfAlgo::SetStreamDistinctionLabel(kFirstStreamSwitchLabel, first_stream_switch_app.get()); | |||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(kGetNextLabel), first_stream_switch_app); | |||||
| CNodePtr second_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | |||||
| MS_EXCEPTION_IF_NULL(second_stream_switch_app); | |||||
| AnfAlgo::SetStreamDistinctionLabel(kSecondStreamSwitchLabel, second_stream_switch_app.get()); | |||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_cnode_stream_label), second_stream_switch_app); | |||||
| // add attr "stream_need_active" | |||||
| AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), second_stream_switch_app); | |||||
| CNodePtr first_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); | |||||
| MS_EXCEPTION_IF_NULL(first_stream_active_app); | |||||
| AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, first_stream_active_app.get()); | |||||
| std::vector<uint32_t> first_active_streams = {kFirstStreamSwitchLabel}; | |||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(first_active_streams), | |||||
| first_stream_active_app); | |||||
| CNodePtr second_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); | |||||
| MS_EXCEPTION_IF_NULL(second_stream_active_app); | |||||
| // specific deal for common ctrl stream policy | |||||
| uint32_t first_common_stream_switch_label = FindFirstStreamSwitchLabel(kernel_graph_ptr); | |||||
| if (first_common_stream_switch_label == kInvalidDistincLabel) { | |||||
| AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, second_stream_active_app.get()); | |||||
| } else { | |||||
| AnfAlgo::SetStreamDistinctionLabel(first_common_stream_switch_label, second_stream_active_app.get()); | |||||
| } | |||||
| CNodePtr stream_active_switch_app = CreateStreamActiveSwitchOp(kernel_graph_ptr); | |||||
| MS_EXCEPTION_IF_NULL(stream_active_switch_app); | |||||
| std::vector<uint32_t> second_active_streams = {kSecondStreamSwitchLabel}; | |||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(second_active_streams), | |||||
| second_stream_active_app); | |||||
| CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); | CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); | ||||
| MS_EXCEPTION_IF_NULL(assign_add_one); | MS_EXCEPTION_IF_NULL(assign_add_one); | ||||
| AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, assign_add_one.get()); | |||||
| CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); | |||||
| AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, send.get()); | |||||
| CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); | |||||
| AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, recv.get()); | |||||
| // reorder graph orders | |||||
| exec_order.push_back(first_stream_switch_app); | |||||
| size_t i = 0; | |||||
| for (; i < orders.size(); i++) { | |||||
| auto node = orders[i]; | |||||
| exec_order.push_back(node); | |||||
| AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, exec_order[exec_order.size() - 1].get()); | |||||
| if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| exec_order.push_back(send); | |||||
| exec_order.push_back(second_stream_switch_app); | |||||
| exec_order.push_back(recv); | |||||
| exec_order.push_back(assign_add_one); | exec_order.push_back(assign_add_one); | ||||
| auto original_exec_order = kernel_graph_ptr->execution_order(); | |||||
| (void)std::copy(original_exec_order.begin(), original_exec_order.end(), std::back_inserter(exec_order)); | |||||
| exec_order.push_back(stream_active_switch_app); | |||||
| std::vector<CNodePtr> memcpy_list; | |||||
| std::vector<CNodePtr> before_list; | |||||
| std::vector<CNodePtr> after_list; | |||||
| bool first_memcpy_found = false; | |||||
| CNodePtr cur_cnode = nullptr; | |||||
| for (size_t idx = i + 1; idx < orders.size(); idx++) { | |||||
| cur_cnode = orders[idx]; | |||||
| if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { | |||||
| memcpy_list.emplace_back(cur_cnode); | |||||
| first_memcpy_found = true; | |||||
| } else if (first_memcpy_found) { | |||||
| after_list.emplace_back(cur_cnode); | |||||
| } else { | |||||
| before_list.emplace_back(cur_cnode); | |||||
| } | |||||
| } | |||||
| (void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order)); | |||||
| (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); | |||||
| exec_order.push_back(first_stream_active_app); | |||||
| (void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order)); | |||||
| exec_order.push_back(second_stream_active_app); | |||||
| kernel_graph_ptr->set_execution_order(exec_order); | kernel_graph_ptr->set_execution_order(exec_order); | ||||
| } | } | ||||
| @@ -167,7 +290,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | ||||
| {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | ||||
| auto typeNone_abstract = std::make_shared<abstract::AbstractNone>(); | auto typeNone_abstract = std::make_shared<abstract::AbstractNone>(); | ||||
| auto stream_switch = std::make_shared<Primitive>(kStreamSwitch); | |||||
| auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName); | |||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| inputs.push_back(NewValueNode(stream_switch)); | inputs.push_back(NewValueNode(stream_switch)); | ||||
| inputs.push_back(switch_loop_input.at(kLoopCountParamName)); | inputs.push_back(switch_loop_input.at(kLoopCountParamName)); | ||||
| @@ -181,28 +304,19 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne | |||||
| int condition = static_cast<int>(RT_LESS); | int condition = static_cast<int>(RT_LESS); | ||||
| ValuePtr cond = MakeValue(condition); | ValuePtr cond = MakeValue(condition); | ||||
| AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); | AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); | ||||
| // set attr:true branch graph id ,which is same to stream distinction label | |||||
| if (kernel_graph_ptr->execution_order().empty()) { | |||||
| MS_LOG(EXCEPTION) << "empty execution order"; | |||||
| } | |||||
| auto first_node = kernel_graph_ptr->execution_order()[0]; | |||||
| auto first_stream = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); | |||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_stream), stream_switch_app); | |||||
| // set attr:data_type | // set attr:data_type | ||||
| int data_type = static_cast<int>(RT_SWITCH_INT64); | int data_type = static_cast<int>(RT_SWITCH_INT64); | ||||
| ValuePtr dt = MakeValue(data_type); | ValuePtr dt = MakeValue(data_type); | ||||
| AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); | AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); | ||||
| // set distinction label and graph id | // set distinction label and graph id | ||||
| AnfAlgo::SetGraphId(kInvalidGraphId - 1, stream_switch_app.get()); | |||||
| AnfAlgo::SetStreamDistinctionLabel(kInvalidDistincLabel - 1, stream_switch_app.get()); | |||||
| return stream_switch_app; | return stream_switch_app; | ||||
| } | } | ||||
| CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||||
| CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | ||||
| {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | ||||
| abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>(); | abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>(); | ||||
| auto stream_active_others = std::make_shared<Primitive>(kStreamActive); | |||||
| auto stream_active_others = std::make_shared<Primitive>(kStreamActiveOpName); | |||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| inputs.push_back(NewValueNode(stream_active_others)); | inputs.push_back(NewValueNode(stream_active_others)); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | ||||
| @@ -213,57 +327,6 @@ CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr<session::Kernel | |||||
| return stream_active_others_app; | return stream_active_others_app; | ||||
| } | } | ||||
| CNodePtr KernelAdjust::CreateStreamActiveSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | |||||
| {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | |||||
| abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>(); | |||||
| auto stream_active_switch = std::make_shared<Primitive>(kStreamActive); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(stream_active_switch)); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | |||||
| CNodePtr stream_active_switch_app = kernel_graph_ptr->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(stream_active_switch_app); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_switch_app.get()); | |||||
| stream_active_switch_app->set_abstract(typeNone_abstract); | |||||
| // set attr,which stream to active | |||||
| std::vector<uint32_t> active_index_value = {kInvalidDistincLabel - 1}; | |||||
| auto value = MakeValue<std::vector<uint32_t>>(active_index_value); | |||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, value, stream_active_switch_app); | |||||
| // set the distinction label of stream active | |||||
| if (kernel_graph_ptr->execution_order().empty()) { | |||||
| MS_LOG(EXCEPTION) << "empty execution order"; | |||||
| } | |||||
| auto first_node = kernel_graph_ptr->execution_order()[0]; | |||||
| auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); | |||||
| // find the first switch's distinction label | |||||
| for (auto node : kernel_graph_ptr->execution_order()) { | |||||
| if (AnfAlgo::GetCNodeName(node) == "StreamSwitch") { | |||||
| label = AnfAlgo::GetStreamDistinctionLabel(node.get()); | |||||
| break; | |||||
| } | |||||
| } | |||||
| AnfAlgo::SetStreamDistinctionLabel(label, stream_active_switch_app.get()); | |||||
| return stream_active_switch_app; | |||||
| } | |||||
| CNodePtr KernelAdjust::CreateStreamActiveOtherOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( | |||||
| {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); | |||||
| abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>(); | |||||
| auto stream_active_others = std::make_shared<Primitive>(kStreamActive); | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.push_back(NewValueNode(stream_active_others)); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | |||||
| CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(stream_active_others_app); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); | |||||
| stream_active_others_app->set_abstract(typeNone_abstract); | |||||
| // set attr | |||||
| ValuePtr active_target = MakeValue(kValueTargetOther); | |||||
| AnfAlgo::SetNodeAttr(kAttrActiveTarget, active_target, stream_active_others_app); | |||||
| return stream_active_others_app; | |||||
| } | |||||
| CNodePtr KernelAdjust::CreateStreamAssignAddnOP( | CNodePtr KernelAdjust::CreateStreamAssignAddnOP( | ||||
| const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) { | const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) { | ||||
| @@ -273,7 +336,7 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( | |||||
| selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); | selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); | ||||
| selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32}); | selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32}); | ||||
| // AssignAdd | // AssignAdd | ||||
| auto assign_add = std::make_shared<Primitive>(kAssignAdd); | |||||
| auto assign_add = std::make_shared<Primitive>(kAssignAddOpName); | |||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| inputs.push_back(NewValueNode(assign_add)); | inputs.push_back(NewValueNode(assign_add)); | ||||
| inputs.push_back(switch_loop_input.at(kLoopCountParamName)); | inputs.push_back(switch_loop_input.at(kLoopCountParamName)); | ||||
| @@ -290,70 +353,9 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( | |||||
| selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); | selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); | ||||
| MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); | MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); | ||||
| assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); | assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); | ||||
| // set the distinction label of assign add | |||||
| if (kernel_graph_ptr->execution_order().empty()) { | |||||
| MS_LOG(EXCEPTION) << "empty execution order"; | |||||
| } | |||||
| auto first_node = kernel_graph_ptr->execution_order()[0]; | |||||
| auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); | |||||
| AnfAlgo::SetStreamDistinctionLabel(label, assign_add_one.get()); | |||||
| return assign_add_one; | return assign_add_one; | ||||
| } | } | ||||
| void KernelAdjust::SetStreamActiveOPs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | |||||
| const std::unordered_set<uint32_t> &ctrl_stream_list, | |||||
| const std::unordered_set<uint32_t> &comm_stream_list, | |||||
| const std::unordered_set<uint32_t> &momentum_stream_list) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | |||||
| for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) { | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||||
| if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) { | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr); | |||||
| ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget); | |||||
| std::vector<uint32_t> index_list; | |||||
| index_list.clear(); | |||||
| if (GetValue<string>(active_target) == kValueTargetSwitch) { | |||||
| index_list.insert(index_list.end(), ctrl_stream_list.begin(), ctrl_stream_list.end()); | |||||
| } else if (GetValue<string>(active_target) == kValueTargetOther) { | |||||
| for (uint32_t index : comm_stream_list) { | |||||
| if (AnfAlgo::GetStreamId(cnode_ptr) == index) { | |||||
| continue; | |||||
| } | |||||
| index_list.emplace_back(index); | |||||
| } | |||||
| index_list.insert(index_list.end(), momentum_stream_list.begin(), momentum_stream_list.end()); | |||||
| } | |||||
| ValuePtr index_list_value = MakeValue(index_list); | |||||
| AnfAlgo::SetNodeAttr(kAttrActiveStreamList, index_list_value, cnode_ptr); | |||||
| } | |||||
| } | |||||
| } | |||||
| void KernelAdjust::SetStreamSwitchOps(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph_ptr); | |||||
| CNodePtr switch_cnode_ptr = nullptr; | |||||
| uint32_t target_stream_id = 0; | |||||
| for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) { | |||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||||
| if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamSwitch) { | |||||
| switch_cnode_ptr = cnode_ptr; | |||||
| } | |||||
| if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) { | |||||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr); | |||||
| ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget); | |||||
| if (GetValue<string>(active_target) == kValueTargetOther) { | |||||
| target_stream_id = AnfAlgo::GetStreamId(cnode_ptr); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (switch_cnode_ptr != nullptr) { | |||||
| // set attr:true stream | |||||
| ValuePtr true_index = MakeValue(target_stream_id); | |||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, true_index, switch_cnode_ptr); | |||||
| MS_LOG(INFO) << "switch to true_index:" << target_stream_id; | |||||
| } | |||||
| } | |||||
| bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context, | bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context, | ||||
| const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | ||||
| if (!NeedInsertSwitch()) { | if (!NeedInsertSwitch()) { | ||||
| @@ -28,10 +28,22 @@ | |||||
| #include "session/session_context.h" | #include "session/session_context.h" | ||||
| #include "ir/meta_tensor.h" | #include "ir/meta_tensor.h" | ||||
| #include "device/ascend/profiling/profiling_utils.h" | #include "device/ascend/profiling/profiling_utils.h" | ||||
| #include "device/kernel_info.h" | |||||
| using mindspore::device::ascend::ProfilingTraceInfo; | using mindspore::device::ascend::ProfilingTraceInfo; | ||||
| using mindspore::device::ascend::ProfilingUtils; | using mindspore::device::ascend::ProfilingUtils; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| constexpr auto kLoopCountParamName = "loop_count"; | |||||
| constexpr auto kIterLoopParamName = "iter_loop"; | |||||
| constexpr auto kZeroParamName = "zero"; | |||||
| constexpr auto kOneParamName = "one"; | |||||
| constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; | |||||
| const uint32_t kFirstStreamSwitchLabel = kInvalidDistincLabel - 1; | |||||
| const uint32_t kGetNextLabel = kInvalidDistincLabel - 2; | |||||
| const uint32_t kSecondStreamSwitchLabel = kInvalidDistincLabel - 3; | |||||
| const uint32_t kInvalidEventId = UINT32_MAX; | |||||
| const uint32_t kFirstEventId = kInvalidEventId / 2; | |||||
| namespace device { | namespace device { | ||||
| class KernelAdjust { | class KernelAdjust { | ||||
| public: | public: | ||||
| @@ -41,26 +53,23 @@ class KernelAdjust { | |||||
| } | } | ||||
| void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | ||||
| void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | ||||
| void SetStreamActiveOPs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | |||||
| const std::unordered_set<uint32_t> &ctrl_stream_list, | |||||
| const std::unordered_set<uint32_t> &comm_stream_list, | |||||
| const std::unordered_set<uint32_t> &momentum_stream_list); | |||||
| void SetStreamSwitchOps(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | |||||
| bool StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context, | bool StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context, | ||||
| const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | ||||
| void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr); | void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr); | ||||
| static bool NeedInsertSwitch(); | static bool NeedInsertSwitch(); | ||||
| CNodePtr CreateSteamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | |||||
| CNodePtr CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | |||||
| private: | private: | ||||
| KernelAdjust() = default; | KernelAdjust() = default; | ||||
| ~KernelAdjust() = default; | ~KernelAdjust() = default; | ||||
| CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id); | |||||
| CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id); | |||||
| uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | |||||
| void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| std::map<std::string, mindspore::ParameterPtr> *switch_loop_input); | std::map<std::string, mindspore::ParameterPtr> *switch_loop_input); | ||||
| CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input); | const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input); | ||||
| CNodePtr CreateStreamActiveSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | |||||
| CNodePtr CreateStreamActiveOtherOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); | |||||
| CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, | ||||
| const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input); | const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input); | ||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats, | kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats, | ||||
| @@ -62,6 +62,7 @@ | |||||
| #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" | #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" | ||||
| #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | ||||
| #include "pre_activate/ascend/ir_fission/addn_fission.h" | #include "pre_activate/ascend/ir_fission/addn_fission.h" | ||||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" | |||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| @@ -187,6 +188,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); | ||||
| } | } | ||||
| if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { | |||||
| ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||||
| } | |||||
| optimizer->AddPassManager(ir_fusion_pm); | optimizer->AddPassManager(ir_fusion_pm); | ||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| kernel_graph->SetExecOrderByDefault(); | kernel_graph->SetExecOrderByDefault(); | ||||
| @@ -20,8 +20,8 @@ namespace mindspore { | |||||
| namespace memreuse { | namespace memreuse { | ||||
| void StreamReuse::SetStreamReuseResource() { | void StreamReuse::SetStreamReuseResource() { | ||||
| #ifdef ENABLE_D | #ifdef ENABLE_D | ||||
| auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().GetPhysicMap(); | |||||
| auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().GetIndependentMap(); | |||||
| auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_physic_map(); | |||||
| auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_independent_map(); | |||||
| MS_LOG(INFO) << "stream mem reuse for Davici"; | MS_LOG(INFO) << "stream mem reuse for Davici"; | ||||
| if (!logic_independent_map.empty() && !logic_physic_map.empty()) { | if (!logic_independent_map.empty() && !logic_physic_map.empty()) { | ||||
| set_logic_physic_map(logic_physic_map); | set_logic_physic_map(logic_physic_map); | ||||
| @@ -610,7 +610,7 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { | |||||
| if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && | if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && | ||||
| ConfigManager::GetInstance().iter_num() > 1) { | ConfigManager::GetInstance().iter_num() > 1) { | ||||
| // insert active in true graph, another active will be inserted in kernel adjust | // insert active in true graph, another active will be inserted in kernel adjust | ||||
| InsertStreamActiveToGraph(true_last_id, kInvalidDistincLabel - 1); | |||||
| InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -114,6 +114,9 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; | |||||
| constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum"; | constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum"; | ||||
| constexpr auto kBiasAddOpName = "BiasAdd"; | constexpr auto kBiasAddOpName = "BiasAdd"; | ||||
| constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; | constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; | ||||
| constexpr auto kStreamSwitchOpName = "StreamSwitch"; | |||||
| constexpr auto kStreamActiveOpName = "StreamActive"; | |||||
| constexpr auto kAssignAddOpName = "AssignAdd"; | |||||
| constexpr auto kSendOpName = "Send"; | constexpr auto kSendOpName = "Send"; | ||||
| constexpr auto kRecvOpName = "Recv"; | constexpr auto kRecvOpName = "Recv"; | ||||
| constexpr auto kReluV2OpName = "ReluV2"; | constexpr auto kReluV2OpName = "ReluV2"; | ||||
| @@ -24,9 +24,7 @@ void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; | |||||
| uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } | uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } | ||||
| std::vector<uint32_t> AscendStreamAssign::GetWaitStreams() { return vector<uint32_t>(); } | |||||
| std::vector<uint32_t> AscendStreamAssign::GetHcomStreams() { return vector<uint32_t>(); } | |||||
| void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) { return; } | |||||
| namespace tasksink { | namespace tasksink { | ||||
| bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::vector<TaskInfoPtr> *const task_info_list, | bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::vector<TaskInfoPtr> *const task_info_list, | ||||