Merge pull request !2490 from gukecai/new-stream-for-committags/v0.6.0-beta
| @@ -426,23 +426,25 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||
| return true; | |||
| } | |||
| AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); | |||
| AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); | |||
| AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); | |||
| AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); | |||
| // the streams' flag not HEAD_STREAM | |||
| std::vector<uint32_t> wait_active_stream_list; | |||
| assign_instance.GetWaitStreams(&wait_active_stream_list); | |||
| std::vector<uint32_t> force_copy_stream_list; | |||
| assign_instance.GetHcomStreams(&force_copy_stream_list); | |||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_manager.GetCurAllocStreamNum() | |||
| << ", total event num:" << assign_instance.total_event_num() | |||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() | |||
| << ", total event num:" << resource_manager.get_cur_event_num() | |||
| << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) | |||
| << ", wait_active_stream_list size:" << wait_active_stream_list.size() | |||
| << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | |||
| std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | |||
| std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( | |||
| task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | |||
| 0, 0, 0, 0, 0, stream_manager.GetCurAllocStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), | |||
| assign_instance.total_event_num(), 0); | |||
| 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), | |||
| resource_manager.get_cur_event_num(), 0); | |||
| auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | |||
| if (!ret.second) { | |||
| MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; | |||
| @@ -29,6 +29,7 @@ | |||
| #include "runtime/rt_model.h" | |||
| #include "runtime/stream.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "utils/contract.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| @@ -38,35 +39,59 @@ using std::shared_ptr; | |||
| using std::unordered_map; | |||
| using std::unordered_set; | |||
| using std::vector; | |||
| using CnodeKey = void *; | |||
| const uint32_t kInvalidStreamId = UINT32_MAX; | |||
| class AscendStreamMng { | |||
| const uint32_t kInvalidEventId = UINT32_MAX; | |||
| class AscendResourceMng { | |||
| public: | |||
| static AscendStreamMng &GetInstance() { | |||
| static AscendStreamMng instance; | |||
| static AscendResourceMng &GetInstance() { | |||
| static AscendResourceMng instance; | |||
| return instance; | |||
| } | |||
| void Reset() { | |||
| cur_stream_id = 0; | |||
| cur_stream_num = 0; | |||
| void ResetResource() { | |||
| cur_stream_num_ = 0; | |||
| cur_event_num_ = 0; | |||
| } | |||
| uint32_t ApplyNewStream() { | |||
| if (!cur_stream_num) { | |||
| cur_stream_num++; | |||
| if (!cur_stream_num_) { | |||
| uint32_t cur_stream_id = cur_stream_num_; | |||
| cur_stream_num_++; | |||
| return cur_stream_id; | |||
| } | |||
| cur_stream_num++; | |||
| cur_stream_id++; | |||
| uint32_t cur_stream_id = cur_stream_num_; | |||
| cur_stream_num_++; | |||
| return cur_stream_id; | |||
| } | |||
| uint32_t ApplyNewEvent() { | |||
| if (!cur_event_num_) { | |||
| uint32_t cur_event_id = cur_event_num_; | |||
| cur_event_num_++; | |||
| return cur_event_id; | |||
| } | |||
| uint32_t cur_event_id = cur_event_num_; | |||
| cur_event_num_++; | |||
| return cur_event_id; | |||
| } | |||
| uint32_t GetCurAllocStream() { return cur_stream_id; } | |||
| uint32_t GetCurAllocStreamNum() { return cur_stream_num; } | |||
| void DeleteEvent() { | |||
| if (!cur_event_num_) { | |||
| MS_LOG(WARNING) << "total event num is 0, no event to delete"; | |||
| } else { | |||
| --cur_event_num_; | |||
| } | |||
| } | |||
| uint32_t get_cur_stream_num() { return cur_stream_num_; } | |||
| uint32_t GetCurAllocStreamId() { | |||
| if (!cur_stream_num_) { | |||
| MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; | |||
| } | |||
| return cur_stream_num_ - 1; | |||
| } | |||
| uint32_t get_cur_event_num() { return cur_event_num_; } | |||
| private: | |||
| uint32_t cur_stream_num{0}; | |||
| uint32_t cur_stream_id{0}; | |||
| uint32_t cur_stream_num_{0}; | |||
| uint32_t cur_event_num_{0}; | |||
| }; | |||
| class AscendStreamAssign { | |||
| @@ -79,39 +104,42 @@ class AscendStreamAssign { | |||
| AscendStreamAssign(const AscendStreamAssign &) = delete; | |||
| AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; | |||
| uint32_t total_event_num() const { return total_event_num_; } | |||
| void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void GetHcomStreams(std::vector<uint32_t> *streams); | |||
| void AssignStream(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void GetWaitStreams(vector<uint32_t> *wait_active_stream_list); | |||
| CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id, | |||
| uint32_t stream_id); | |||
| CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id, | |||
| uint32_t stream_id); | |||
| CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); | |||
| CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); | |||
| private: | |||
| AscendStreamAssign() = default; | |||
| ~AscendStreamAssign() = default; | |||
| void Reset(); | |||
| void CheckStreamAssign(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, | |||
| uint32_t *cur_stream_id); | |||
| void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); | |||
| void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr); | |||
| void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); | |||
| void UpdateAtomicAddrCleanStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void FindHcomParallelStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void InsertStreamActive(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void UpdateStreamSwitch(const std::shared_ptr<session::KernelGraph> &graph_ptr, const CNodePtr &switch_ptr, | |||
| const vector<uint32_t> &independent_stream, vector<CNodePtr> *orders); | |||
| void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr); | |||
| void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||
| void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr, | |||
| vector<CNodePtr> *orders); | |||
| void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, const map<uint32_t, vector<size_t>> &hcom_index, | |||
| uint32_t first_hcom_stream, uint32_t last_hcom_stream); | |||
| bool IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, size_t index); | |||
| void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| bool IsTaskSink(); | |||
| bool IsFusionHcom(const CNodePtr &cur_cnode_ptr); | |||
| bool IsHcom(const CNodePtr &cur_cnode_ptr); | |||
| bool IsIndependentNode(const CNodePtr &node_ptr); | |||
| bool IsProcessedStream(uint32_t stream_id); | |||
| @@ -119,14 +147,13 @@ class AscendStreamAssign { | |||
| const CNodePtr &node); | |||
| void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams); | |||
| uint32_t total_event_num_{0}; | |||
| bool independent_stream_activated_{false}; | |||
| bool hcom_stream_activated_{false}; | |||
| std::map<uint32_t, uint32_t> independent_stream_map_{}; | |||
| std::map<uint32_t, uint32_t> hcom_stream_map_{}; | |||
| std::map<uint32_t, uint32_t> common_stream_map_{}; | |||
| std::set<uint32_t> processed_streams_{}; | |||
| std::set<uint32_t> hcom_stream_list_{}; | |||
| std::vector<uint32_t> need_first_active_streams_{}; | |||
| std::vector<std::vector<uint32_t>> inner_parallel_streams_{}; | |||
| // new policy end | |||
| }; | |||
| } // namespace ascend | |||
| @@ -103,8 +103,8 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::Kern | |||
| } | |||
| void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||
| device::ascend::AscendStreamMng &stream_manager = device::ascend::AscendStreamMng::GetInstance(); | |||
| stream_manager.Reset(); | |||
| device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); | |||
| resource_manager.ResetResource(); | |||
| if (!NeedInsertSwitch()) { | |||
| return; | |||
| } | |||
| @@ -135,17 +135,16 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||
| } | |||
| std::vector<CNodePtr> exec_order; | |||
| // getnext loop process | |||
| // getnext loop stream switch op | |||
| CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | |||
| MS_EXCEPTION_IF_NULL(getnext_switch_app); | |||
| uint32_t getnext_switch_stream_id = stream_manager.ApplyNewStream(); | |||
| uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); | |||
| AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); | |||
| exec_order.push_back(getnext_switch_app); | |||
| // getnext op | |||
| uint32_t getnext_stream_id = stream_manager.ApplyNewStream(); | |||
| uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); | |||
| size_t i = 0; | |||
| for (; i < orders.size(); i++) { | |||
| auto node = orders[i]; | |||
| @@ -160,7 +159,8 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app); | |||
| // getnext loop send | |||
| CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); | |||
| uint32_t getnext_event_id = resource_manager.ApplyNewEvent(); | |||
| CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, getnext_event_id); | |||
| AnfAlgo::SetStreamId(getnext_stream_id, send.get()); | |||
| exec_order.push_back(send); | |||
| @@ -168,14 +168,14 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||
| // fpbp loop stream switch | |||
| CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | |||
| MS_EXCEPTION_IF_NULL(fpbp_switch_app); | |||
| uint32_t fpbp_switch_stream_id = stream_manager.ApplyNewStream(); | |||
| uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); | |||
| AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); | |||
| AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app); | |||
| exec_order.push_back(fpbp_switch_app); | |||
| // fpbp loop recv | |||
| CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); | |||
| uint32_t fpbp_stream_id = stream_manager.ApplyNewStream(); | |||
| CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, getnext_event_id); | |||
| uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); | |||
| AnfAlgo::SetStreamId(fpbp_stream_id, recv.get()); | |||
| exec_order.push_back(recv); | |||
| @@ -38,12 +38,8 @@ constexpr auto kIterLoopParamName = "iter_loop"; | |||
| constexpr auto kZeroParamName = "zero"; | |||
| constexpr auto kOneParamName = "one"; | |||
| constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; | |||
| constexpr uint32_t kSecondStreamSwitchLabel = 2; | |||
| const uint32_t kFirstStreamSwitchLabel = 0; | |||
| const uint32_t kGetNextLabel = 1; | |||
| const uint32_t kSecondStreamSwitchLabel = 2; | |||
| const uint32_t kInvalidEventId = UINT32_MAX; | |||
| const uint32_t kFirstEventId = kInvalidEventId / 2; | |||
| namespace device { | |||
| class KernelAdjust { | |||
| public: | |||
| @@ -305,7 +305,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| // adjust kernel | |||
| AdjustKernel(root_graph); | |||
| // assign stream | |||
| AssignStream(root_graph); | |||
| AssignStream(NOT_NULL(root_graph)); | |||
| // insert profiling point | |||
| device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); | |||
| // build kernel | |||
| @@ -377,7 +377,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||
| // adjust execution order because merge child graph and other special operations | |||
| AdjustKernel(graph); | |||
| // Assign streams for control sink and hccl and so on | |||
| AssignStream(graph); | |||
| AssignStream(NOT_NULL(graph)); | |||
| device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); | |||
| // build kernel if node is cnode | |||
| @@ -647,7 +647,7 @@ void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel | |||
| MS_LOG(INFO) << "Finish!"; | |||
| } | |||
| void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||
| void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const { | |||
| MS_LOG(INFO) << "Start!"; | |||
| device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); | |||
| MS_LOG(INFO) << "Finish!"; | |||
| @@ -76,7 +76,7 @@ class AscendSession : public SessionBasic { | |||
| void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const; | |||
| void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const; | |||
| void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void MemoryAlloc(KernelGraph *kernel_graph) const; | |||
| @@ -26,7 +26,7 @@ void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph | |||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; } | |||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; } | |||
| void AscendStreamAssign::AssignStream(const KernelGraphPtr &graph) { return; } | |||
| void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) { return; } | |||
| void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) { return; } | |||