| @@ -35,7 +35,7 @@ Status CallbackManager::Init() { | |||
| } | |||
| Status CallbackManager::CallbackProcess() { | |||
| std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> entry; | |||
| std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>> entry; | |||
| while (true) { | |||
| if (!callback_queue_.Pop(&entry)) { | |||
| MS_LOG(INFO) << "CallbackManager stopped"; | |||
| @@ -84,7 +84,7 @@ Status CallbackManager::Destroy() { | |||
| return ret; | |||
| } | |||
| Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) { | |||
| Status CallbackManager::RegisterCallback(rtCallback_t callback, const void *user_data) { | |||
| MS_LOG(INFO) << "To register callback"; | |||
| rtEvent_t event = nullptr; | |||
| auto ret = rtEventCreate(&event); | |||
| @@ -98,8 +98,8 @@ Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) | |||
| MS_LOG(ERROR) << "Record event failed"; | |||
| return kFail; | |||
| } | |||
| auto cb = std::pair<rtCallback_t, void *>(callback, user_data); | |||
| auto entry = std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>(event, std::move(cb)); | |||
| auto cb = std::pair<rtCallback_t, const void *>(callback, user_data); | |||
| auto entry = std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>>(event, std::move(cb)); | |||
| if (!callback_queue_.Push(entry)) { | |||
| return kFail; | |||
| } | |||
| @@ -108,9 +108,9 @@ Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) | |||
| return kSuccess; | |||
| } | |||
| void CallbackManager::RtCallbackFunc(void *data) { | |||
| void CallbackManager::RtCallbackFunc(const void *data) { | |||
| MS_LOG(INFO) << "To invoke callback function"; | |||
| auto callback_func = reinterpret_cast<std::function<void()> *>(data); | |||
| auto callback_func = reinterpret_cast<const std::function<void()> *>(data); | |||
| (*callback_func)(); | |||
| delete callback_func; | |||
| } | |||
| @@ -24,11 +24,10 @@ | |||
| #include <utility> | |||
| #include "profiler/device/ascend/blocking_queue.h" | |||
| #include "runtime/base.h" | |||
| namespace mindspore { | |||
| namespace profiler { | |||
| namespace ascend { | |||
| using rtCallback_t = std::function<void(void *)>; | |||
| using rtCallback_t = std::function<void(const void *)>; | |||
| enum Status { kSuccess = 0, kFail, kInvalidParam }; | |||
| class CallbackManager { | |||
| public: | |||
| @@ -45,14 +44,14 @@ class CallbackManager { | |||
| Status Destroy(); | |||
| Status RegisterCallback(rtCallback_t callback, void *user_data); | |||
| Status RegisterCallback(rtCallback_t callback, const void *user_data); | |||
| Status RegisterCallback(const std::function<void()> &callback); | |||
| private: | |||
| Status CallbackProcess(); | |||
| static void RtCallbackFunc(void *data); | |||
| static void RtCallbackFunc(const void *data); | |||
| BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>> callback_queue_; | |||
| BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>>> callback_queue_; | |||
| rtStream_t stream_; | |||
| std::future<Status> ret_future_; | |||
| }; | |||
| @@ -38,6 +38,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) | |||
| Reset(); | |||
| SetLoopSink(); | |||
| ReorderIndependentOrders(graph_ptr); | |||
| AssignAllNodesStream(graph_ptr); | |||
| UpdateAtomicAddrCleanStreamId(graph_ptr); | |||
| InsertStreamActive(graph_ptr); | |||
| @@ -1438,19 +1439,19 @@ void AscendStreamAssign::Reset() { | |||
| } | |||
| // section 10 | |||
| bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) { | |||
| auto group_size = group->size(); | |||
| bool AscendStreamAssign::IsVecExist(const std::vector<uint32_t> &group) { | |||
| auto group_size = group.size(); | |||
| if (group_size == 0) { | |||
| return false; | |||
| } | |||
| for (const auto &item : stream_groups_) { | |||
| if (item.size() < group->size()) { | |||
| if (item.size() < group.size()) { | |||
| continue; | |||
| } | |||
| bool flag = true; | |||
| for (size_t i = 0; i < group_size; i++) { | |||
| if (item[i] != group->at(i)) { | |||
| if (item[i] != group.at(i)) { | |||
| flag = false; | |||
| break; | |||
| } | |||
| @@ -1469,7 +1470,7 @@ bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) { | |||
| void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) { | |||
| auto it = stream_relations_.find(start); | |||
| if (it == stream_relations_.end()) { | |||
| if (!IsVecExist(group)) { | |||
| if (!IsVecExist(*group)) { | |||
| stream_groups_.emplace_back(*group); | |||
| } else { | |||
| MS_LOG(WARNING) << "DFS find same stream group, Not expected"; | |||
| @@ -1781,7 +1782,6 @@ void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph | |||
| MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId); | |||
| } | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -172,7 +172,7 @@ class AscendStreamAssign { | |||
| // function for memory resue | |||
| void GetStreamRelations(); | |||
| void DFS(uint32_t start, std::vector<uint32_t> *group); | |||
| bool IsVecExist(std::vector<uint32_t> *group); | |||
| bool IsVecExist(const std::vector<uint32_t> &group); | |||
| void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr); | |||
| void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); | |||
| void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index); | |||