Merge pull request !2490 from gukecai/new-stream-for-committags/v0.6.0-beta
| @@ -426,23 +426,25 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); | AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); | ||||
| AscendStreamMng &stream_manager = AscendStreamMng::GetInstance(); | |||||
| AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); | |||||
| AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); | AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); | ||||
| // the streams' flag not HEAD_STREAM | // the streams' flag not HEAD_STREAM | ||||
| std::vector<uint32_t> wait_active_stream_list; | std::vector<uint32_t> wait_active_stream_list; | ||||
| assign_instance.GetWaitStreams(&wait_active_stream_list); | assign_instance.GetWaitStreams(&wait_active_stream_list); | ||||
| std::vector<uint32_t> force_copy_stream_list; | std::vector<uint32_t> force_copy_stream_list; | ||||
| assign_instance.GetHcomStreams(&force_copy_stream_list); | assign_instance.GetHcomStreams(&force_copy_stream_list); | ||||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_manager.GetCurAllocStreamNum() | |||||
| << ", total event num:" << assign_instance.total_event_num() | |||||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() | |||||
| << ", total event num:" << resource_manager.get_cur_event_num() | |||||
| << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) | << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) | ||||
| << ", wait_active_stream_list size:" << wait_active_stream_list.size() | << ", wait_active_stream_list size:" << wait_active_stream_list.size() | ||||
| << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | ||||
| std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | ||||
| std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( | std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( | ||||
| task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | ||||
| 0, 0, 0, 0, 0, stream_manager.GetCurAllocStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), | |||||
| assign_instance.total_event_num(), 0); | |||||
| 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), | |||||
| resource_manager.get_cur_event_num(), 0); | |||||
| auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | ||||
| if (!ret.second) { | if (!ret.second) { | ||||
| MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; | MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "runtime/rt_model.h" | #include "runtime/rt_model.h" | ||||
| #include "runtime/stream.h" | #include "runtime/stream.h" | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "utils/contract.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| @@ -38,35 +39,59 @@ using std::shared_ptr; | |||||
| using std::unordered_map; | using std::unordered_map; | ||||
| using std::unordered_set; | using std::unordered_set; | ||||
| using std::vector; | using std::vector; | ||||
| using CnodeKey = void *; | |||||
| const uint32_t kInvalidStreamId = UINT32_MAX; | const uint32_t kInvalidStreamId = UINT32_MAX; | ||||
| class AscendStreamMng { | |||||
| const uint32_t kInvalidEventId = UINT32_MAX; | |||||
| class AscendResourceMng { | |||||
| public: | public: | ||||
| static AscendStreamMng &GetInstance() { | |||||
| static AscendStreamMng instance; | |||||
| static AscendResourceMng &GetInstance() { | |||||
| static AscendResourceMng instance; | |||||
| return instance; | return instance; | ||||
| } | } | ||||
| void Reset() { | |||||
| cur_stream_id = 0; | |||||
| cur_stream_num = 0; | |||||
| void ResetResource() { | |||||
| cur_stream_num_ = 0; | |||||
| cur_event_num_ = 0; | |||||
| } | } | ||||
| uint32_t ApplyNewStream() { | uint32_t ApplyNewStream() { | ||||
| if (!cur_stream_num) { | |||||
| cur_stream_num++; | |||||
| if (!cur_stream_num_) { | |||||
| uint32_t cur_stream_id = cur_stream_num_; | |||||
| cur_stream_num_++; | |||||
| return cur_stream_id; | return cur_stream_id; | ||||
| } | } | ||||
| cur_stream_num++; | |||||
| cur_stream_id++; | |||||
| uint32_t cur_stream_id = cur_stream_num_; | |||||
| cur_stream_num_++; | |||||
| return cur_stream_id; | return cur_stream_id; | ||||
| } | } | ||||
| uint32_t ApplyNewEvent() { | |||||
| if (!cur_event_num_) { | |||||
| uint32_t cur_event_id = cur_event_num_; | |||||
| cur_event_num_++; | |||||
| return cur_event_id; | |||||
| } | |||||
| uint32_t cur_event_id = cur_event_num_; | |||||
| cur_event_num_++; | |||||
| return cur_event_id; | |||||
| } | |||||
| uint32_t GetCurAllocStream() { return cur_stream_id; } | |||||
| uint32_t GetCurAllocStreamNum() { return cur_stream_num; } | |||||
| void DeleteEvent() { | |||||
| if (!cur_event_num_) { | |||||
| MS_LOG(WARNING) << "total event num is 0, no event to delete"; | |||||
| } else { | |||||
| --cur_event_num_; | |||||
| } | |||||
| } | |||||
| uint32_t get_cur_stream_num() { return cur_stream_num_; } | |||||
| uint32_t GetCurAllocStreamId() { | |||||
| if (!cur_stream_num_) { | |||||
| MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; | |||||
| } | |||||
| return cur_stream_num_ - 1; | |||||
| } | |||||
| uint32_t get_cur_event_num() { return cur_event_num_; } | |||||
| private: | private: | ||||
| uint32_t cur_stream_num{0}; | |||||
| uint32_t cur_stream_id{0}; | |||||
| uint32_t cur_stream_num_{0}; | |||||
| uint32_t cur_event_num_{0}; | |||||
| }; | }; | ||||
| class AscendStreamAssign { | class AscendStreamAssign { | ||||
| @@ -79,39 +104,42 @@ class AscendStreamAssign { | |||||
| AscendStreamAssign(const AscendStreamAssign &) = delete; | AscendStreamAssign(const AscendStreamAssign &) = delete; | ||||
| AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; | AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; | ||||
| uint32_t total_event_num() const { return total_event_num_; } | |||||
| void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void GetHcomStreams(std::vector<uint32_t> *streams); | void GetHcomStreams(std::vector<uint32_t> *streams); | ||||
| void AssignStream(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void GetWaitStreams(vector<uint32_t> *wait_active_stream_list); | void GetWaitStreams(vector<uint32_t> *wait_active_stream_list); | ||||
| CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id, | |||||
| uint32_t stream_id); | |||||
| CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id, | |||||
| uint32_t stream_id); | |||||
| CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); | |||||
| CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id); | |||||
| private: | private: | ||||
| AscendStreamAssign() = default; | AscendStreamAssign() = default; | ||||
| ~AscendStreamAssign() = default; | ~AscendStreamAssign() = default; | ||||
| void Reset(); | void Reset(); | ||||
| void CheckStreamAssign(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, | |||||
| uint32_t *cur_stream_id); | |||||
| void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); | |||||
| void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr); | |||||
| void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); | void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); | ||||
| void UpdateAtomicAddrCleanStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void FindHcomParallelStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void InsertStreamActive(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void UpdateStreamSwitch(const std::shared_ptr<session::KernelGraph> &graph_ptr, const CNodePtr &switch_ptr, | |||||
| const vector<uint32_t> &independent_stream, vector<CNodePtr> *orders); | |||||
| void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr); | |||||
| void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr); | |||||
| void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr, | |||||
| vector<CNodePtr> *orders); | |||||
| void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, const map<uint32_t, vector<size_t>> &hcom_index, | |||||
| uint32_t first_hcom_stream, uint32_t last_hcom_stream); | |||||
| bool IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, size_t index); | |||||
| void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr); | |||||
| bool IsTaskSink(); | bool IsTaskSink(); | ||||
| bool IsFusionHcom(const CNodePtr &cur_cnode_ptr); | |||||
| bool IsHcom(const CNodePtr &cur_cnode_ptr); | bool IsHcom(const CNodePtr &cur_cnode_ptr); | ||||
| bool IsIndependentNode(const CNodePtr &node_ptr); | bool IsIndependentNode(const CNodePtr &node_ptr); | ||||
| bool IsProcessedStream(uint32_t stream_id); | bool IsProcessedStream(uint32_t stream_id); | ||||
| @@ -119,14 +147,13 @@ class AscendStreamAssign { | |||||
| const CNodePtr &node); | const CNodePtr &node); | ||||
| void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams); | void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams); | ||||
| uint32_t total_event_num_{0}; | |||||
| bool independent_stream_activated_{false}; | bool independent_stream_activated_{false}; | ||||
| bool hcom_stream_activated_{false}; | |||||
| std::map<uint32_t, uint32_t> independent_stream_map_{}; | std::map<uint32_t, uint32_t> independent_stream_map_{}; | ||||
| std::map<uint32_t, uint32_t> hcom_stream_map_{}; | |||||
| std::map<uint32_t, uint32_t> common_stream_map_{}; | |||||
| std::set<uint32_t> processed_streams_{}; | std::set<uint32_t> processed_streams_{}; | ||||
| std::set<uint32_t> hcom_stream_list_{}; | |||||
| std::vector<uint32_t> need_first_active_streams_{}; | std::vector<uint32_t> need_first_active_streams_{}; | ||||
| std::vector<std::vector<uint32_t>> inner_parallel_streams_{}; | |||||
| // new policy end | // new policy end | ||||
| }; | }; | ||||
| } // namespace ascend | } // namespace ascend | ||||
| @@ -103,8 +103,8 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::Kern | |||||
| } | } | ||||
| void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | ||||
| device::ascend::AscendStreamMng &stream_manager = device::ascend::AscendStreamMng::GetInstance(); | |||||
| stream_manager.Reset(); | |||||
| device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); | |||||
| resource_manager.ResetResource(); | |||||
| if (!NeedInsertSwitch()) { | if (!NeedInsertSwitch()) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -135,17 +135,16 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| } | } | ||||
| std::vector<CNodePtr> exec_order; | std::vector<CNodePtr> exec_order; | ||||
| // getnext loop process | // getnext loop process | ||||
| // getnext loop stream switch op | // getnext loop stream switch op | ||||
| CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | ||||
| MS_EXCEPTION_IF_NULL(getnext_switch_app); | MS_EXCEPTION_IF_NULL(getnext_switch_app); | ||||
| uint32_t getnext_switch_stream_id = stream_manager.ApplyNewStream(); | |||||
| uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); | |||||
| AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); | AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); | ||||
| exec_order.push_back(getnext_switch_app); | exec_order.push_back(getnext_switch_app); | ||||
| // getnext op | // getnext op | ||||
| uint32_t getnext_stream_id = stream_manager.ApplyNewStream(); | |||||
| uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); | |||||
| size_t i = 0; | size_t i = 0; | ||||
| for (; i < orders.size(); i++) { | for (; i < orders.size(); i++) { | ||||
| auto node = orders[i]; | auto node = orders[i]; | ||||
| @@ -160,7 +159,8 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app); | AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app); | ||||
| // getnext loop send | // getnext loop send | ||||
| CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); | |||||
| uint32_t getnext_event_id = resource_manager.ApplyNewEvent(); | |||||
| CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, getnext_event_id); | |||||
| AnfAlgo::SetStreamId(getnext_stream_id, send.get()); | AnfAlgo::SetStreamId(getnext_stream_id, send.get()); | ||||
| exec_order.push_back(send); | exec_order.push_back(send); | ||||
| @@ -168,14 +168,14 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> | |||||
| // fpbp loop stream switch | // fpbp loop stream switch | ||||
| CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); | ||||
| MS_EXCEPTION_IF_NULL(fpbp_switch_app); | MS_EXCEPTION_IF_NULL(fpbp_switch_app); | ||||
| uint32_t fpbp_switch_stream_id = stream_manager.ApplyNewStream(); | |||||
| uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); | |||||
| AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); | AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); | ||||
| AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app); | AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app); | ||||
| exec_order.push_back(fpbp_switch_app); | exec_order.push_back(fpbp_switch_app); | ||||
| // fpbp loop recv | // fpbp loop recv | ||||
| CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); | |||||
| uint32_t fpbp_stream_id = stream_manager.ApplyNewStream(); | |||||
| CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, getnext_event_id); | |||||
| uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); | |||||
| AnfAlgo::SetStreamId(fpbp_stream_id, recv.get()); | AnfAlgo::SetStreamId(fpbp_stream_id, recv.get()); | ||||
| exec_order.push_back(recv); | exec_order.push_back(recv); | ||||
| @@ -38,12 +38,8 @@ constexpr auto kIterLoopParamName = "iter_loop"; | |||||
| constexpr auto kZeroParamName = "zero"; | constexpr auto kZeroParamName = "zero"; | ||||
| constexpr auto kOneParamName = "one"; | constexpr auto kOneParamName = "one"; | ||||
| constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; | constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; | ||||
| constexpr uint32_t kSecondStreamSwitchLabel = 2; | |||||
| const uint32_t kFirstStreamSwitchLabel = 0; | |||||
| const uint32_t kGetNextLabel = 1; | |||||
| const uint32_t kSecondStreamSwitchLabel = 2; | |||||
| const uint32_t kInvalidEventId = UINT32_MAX; | |||||
| const uint32_t kFirstEventId = kInvalidEventId / 2; | |||||
| namespace device { | namespace device { | ||||
| class KernelAdjust { | class KernelAdjust { | ||||
| public: | public: | ||||
| @@ -305,7 +305,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| // adjust kernel | // adjust kernel | ||||
| AdjustKernel(root_graph); | AdjustKernel(root_graph); | ||||
| // assign stream | // assign stream | ||||
| AssignStream(root_graph); | |||||
| AssignStream(NOT_NULL(root_graph)); | |||||
| // insert profiling point | // insert profiling point | ||||
| device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); | device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); | ||||
| // build kernel | // build kernel | ||||
| @@ -377,7 +377,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||||
| // adjust execution order because merge child graph and other special operations | // adjust execution order because merge child graph and other special operations | ||||
| AdjustKernel(graph); | AdjustKernel(graph); | ||||
| // Assign streams for control sink and hccl and so on | // Assign streams for control sink and hccl and so on | ||||
| AssignStream(graph); | |||||
| AssignStream(NOT_NULL(graph)); | |||||
| device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); | device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); | ||||
| // build kernel if node is cnode | // build kernel if node is cnode | ||||
| @@ -647,7 +647,7 @@ void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel | |||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const { | |||||
| void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const { | |||||
| MS_LOG(INFO) << "Start!"; | MS_LOG(INFO) << "Start!"; | ||||
| device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); | device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); | ||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| @@ -76,7 +76,7 @@ class AscendSession : public SessionBasic { | |||||
| void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||||
| void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const; | |||||
| void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const; | void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const; | ||||
| void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void MemoryAlloc(KernelGraph *kernel_graph) const; | void MemoryAlloc(KernelGraph *kernel_graph) const; | ||||
| @@ -26,7 +26,7 @@ void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph | |||||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; } | uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; } | ||||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; } | uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; } | ||||
| void AscendStreamAssign::AssignStream(const KernelGraphPtr &graph) { return; } | |||||
| void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) { return; } | |||||
| void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) { return; } | void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) { return; } | ||||