diff --git a/mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.cc b/mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.cc index 48ddec7577..5ae51db8f6 100644 --- a/mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.cc +++ b/mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.cc @@ -35,7 +35,7 @@ Status CallbackManager::Init() { } Status CallbackManager::CallbackProcess() { - std::pair> entry; + std::pair> entry; while (true) { if (!callback_queue_.Pop(&entry)) { MS_LOG(INFO) << "CallbackManager stopped"; @@ -84,7 +84,7 @@ Status CallbackManager::Destroy() { return ret; } -Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) { +Status CallbackManager::RegisterCallback(rtCallback_t callback, const void *user_data) { MS_LOG(INFO) << "To register callback"; rtEvent_t event = nullptr; auto ret = rtEventCreate(&event); @@ -98,8 +98,8 @@ Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) MS_LOG(ERROR) << "Record event failed"; return kFail; } - auto cb = std::pair(callback, user_data); - auto entry = std::pair>(event, std::move(cb)); + auto cb = std::pair(callback, user_data); + auto entry = std::pair>(event, std::move(cb)); if (!callback_queue_.Push(entry)) { return kFail; } @@ -108,9 +108,9 @@ Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) return kSuccess; } -void CallbackManager::RtCallbackFunc(void *data) { +void CallbackManager::RtCallbackFunc(const void *data) { MS_LOG(INFO) << "To invoke callback function"; - auto callback_func = reinterpret_cast *>(data); + auto callback_func = reinterpret_cast *>(data); (*callback_func)(); delete callback_func; } diff --git a/mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.h b/mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.h index 2084036f3e..9d87f26160 100644 --- a/mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.h +++ b/mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.h @@ -24,11 +24,10 @@ #include #include "profiler/device/ascend/blocking_queue.h" #include "runtime/base.h" - namespace mindspore { namespace profiler { namespace ascend { -using rtCallback_t = std::function; +using rtCallback_t = std::function; enum Status { kSuccess = 0, kFail, kInvalidParam }; class CallbackManager { public: @@ -45,14 +44,14 @@ class CallbackManager { Status Destroy(); - Status RegisterCallback(rtCallback_t callback, void *user_data); + Status RegisterCallback(rtCallback_t callback, const void *user_data); Status RegisterCallback(const std::function &callback); private: Status CallbackProcess(); - static void RtCallbackFunc(void *data); + static void RtCallbackFunc(const void *data); - BlockingQueue>> callback_queue_; + BlockingQueue>> callback_queue_; rtStream_t stream_; std::future ret_future_; }; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index da66d338fb..f7b33b30a9 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -38,6 +38,7 @@ void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) Reset(); SetLoopSink(); ReorderIndependentOrders(graph_ptr); + AssignAllNodesStream(graph_ptr); UpdateAtomicAddrCleanStreamId(graph_ptr); InsertStreamActive(graph_ptr); @@ -1438,19 +1439,19 @@ void AscendStreamAssign::Reset() { } // section 10 -bool AscendStreamAssign::IsVecExist(std::vector *group) { - auto group_size = group->size(); +bool AscendStreamAssign::IsVecExist(const std::vector &group) { + auto group_size = group.size(); if (group_size == 0) { return false; } for (const auto &item : stream_groups_) { - if (item.size() < group->size()) { + if (item.size() < group.size()) { continue; } bool flag = true; for (size_t i = 0; i < group_size; i++) { - if (item[i] != group->at(i)) { + if (item[i] != group.at(i)) { flag = false; break; } @@ -1469,7 +1470,7 @@ bool AscendStreamAssign::IsVecExist(std::vector *group) { void AscendStreamAssign::DFS(uint32_t start, std::vector *group) { auto it = stream_relations_.find(start); if (it == stream_relations_.end()) { - if (!IsVecExist(group)) { + if (!IsVecExist(*group)) { stream_groups_.emplace_back(*group); } else { MS_LOG(WARNING) << "DFS find same stream group, Not expected"; @@ -1781,7 +1782,6 @@ void AscendStreamAssign::FindEventRelations(const NotNull &graph MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr(item.first, kAttrEventId); } } - } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h index 8fc2e9427d..138892073a 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -172,7 +172,7 @@ class AscendStreamAssign { // function for memory resue void GetStreamRelations(); void DFS(uint32_t start, std::vector *group); - bool IsVecExist(std::vector *group); + bool IsVecExist(const std::vector &group); void FindStreamRelations(const NotNull &graph_ptr); void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); void GetStreamActiveStreamRelation(const NotNull &graph_ptr, size_t index);