| @@ -35,7 +35,7 @@ Status CallbackManager::Init() { | |||||
| } | } | ||||
| Status CallbackManager::CallbackProcess() { | Status CallbackManager::CallbackProcess() { | ||||
| std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> entry; | |||||
| std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>> entry; | |||||
| while (true) { | while (true) { | ||||
| if (!callback_queue_.Pop(&entry)) { | if (!callback_queue_.Pop(&entry)) { | ||||
| MS_LOG(INFO) << "CallbackManager stopped"; | MS_LOG(INFO) << "CallbackManager stopped"; | ||||
| @@ -84,7 +84,7 @@ Status CallbackManager::Destroy() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) { | |||||
| Status CallbackManager::RegisterCallback(rtCallback_t callback, const void *user_data) { | |||||
| MS_LOG(INFO) << "To register callback"; | MS_LOG(INFO) << "To register callback"; | ||||
| rtEvent_t event = nullptr; | rtEvent_t event = nullptr; | ||||
| auto ret = rtEventCreate(&event); | auto ret = rtEventCreate(&event); | ||||
| @@ -98,8 +98,8 @@ Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) | |||||
| MS_LOG(ERROR) << "Record event failed"; | MS_LOG(ERROR) << "Record event failed"; | ||||
| return kFail; | return kFail; | ||||
| } | } | ||||
| auto cb = std::pair<rtCallback_t, void *>(callback, user_data); | |||||
| auto entry = std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>(event, std::move(cb)); | |||||
| auto cb = std::pair<rtCallback_t, const void *>(callback, user_data); | |||||
| auto entry = std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>>(event, std::move(cb)); | |||||
| if (!callback_queue_.Push(entry)) { | if (!callback_queue_.Push(entry)) { | ||||
| return kFail; | return kFail; | ||||
| } | } | ||||
| @@ -108,9 +108,9 @@ Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) | |||||
| return kSuccess; | return kSuccess; | ||||
| } | } | ||||
| void CallbackManager::RtCallbackFunc(void *data) { | |||||
| void CallbackManager::RtCallbackFunc(const void *data) { | |||||
| MS_LOG(INFO) << "To invoke callback function"; | MS_LOG(INFO) << "To invoke callback function"; | ||||
| auto callback_func = reinterpret_cast<std::function<void()> *>(data); | |||||
| auto callback_func = reinterpret_cast<const std::function<void()> *>(data); | |||||
| (*callback_func)(); | (*callback_func)(); | ||||
| delete callback_func; | delete callback_func; | ||||
| } | } | ||||
| @@ -24,11 +24,10 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "profiler/device/ascend/blocking_queue.h" | #include "profiler/device/ascend/blocking_queue.h" | ||||
| #include "runtime/base.h" | #include "runtime/base.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace profiler { | namespace profiler { | ||||
| namespace ascend { | namespace ascend { | ||||
| using rtCallback_t = std::function<void(void *)>; | |||||
| using rtCallback_t = std::function<void(const void *)>; | |||||
| enum Status { kSuccess = 0, kFail, kInvalidParam }; | enum Status { kSuccess = 0, kFail, kInvalidParam }; | ||||
| class CallbackManager { | class CallbackManager { | ||||
| public: | public: | ||||
| @@ -45,14 +44,14 @@ class CallbackManager { | |||||
| Status Destroy(); | Status Destroy(); | ||||
| Status RegisterCallback(rtCallback_t callback, void *user_data); | |||||
| Status RegisterCallback(rtCallback_t callback, const void *user_data); | |||||
| Status RegisterCallback(const std::function<void()> &callback); | Status RegisterCallback(const std::function<void()> &callback); | ||||
| private: | private: | ||||
| Status CallbackProcess(); | Status CallbackProcess(); | ||||
| static void RtCallbackFunc(void *data); | |||||
| static void RtCallbackFunc(const void *data); | |||||
| BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>> callback_queue_; | |||||
| BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>>> callback_queue_; | |||||
| rtStream_t stream_; | rtStream_t stream_; | ||||
| std::future<Status> ret_future_; | std::future<Status> ret_future_; | ||||
| }; | }; | ||||
| @@ -38,6 +38,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) | |||||
| Reset(); | Reset(); | ||||
| SetLoopSink(); | SetLoopSink(); | ||||
| ReorderIndependentOrders(graph_ptr); | ReorderIndependentOrders(graph_ptr); | ||||
| AssignAllNodesStream(graph_ptr); | AssignAllNodesStream(graph_ptr); | ||||
| UpdateAtomicAddrCleanStreamId(graph_ptr); | UpdateAtomicAddrCleanStreamId(graph_ptr); | ||||
| InsertStreamActive(graph_ptr); | InsertStreamActive(graph_ptr); | ||||
| @@ -1438,19 +1439,19 @@ void AscendStreamAssign::Reset() { | |||||
| } | } | ||||
| // section 10 | // section 10 | ||||
| bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) { | |||||
| auto group_size = group->size(); | |||||
| bool AscendStreamAssign::IsVecExist(const std::vector<uint32_t> &group) { | |||||
| auto group_size = group.size(); | |||||
| if (group_size == 0) { | if (group_size == 0) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| for (const auto &item : stream_groups_) { | for (const auto &item : stream_groups_) { | ||||
| if (item.size() < group->size()) { | |||||
| if (item.size() < group.size()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| bool flag = true; | bool flag = true; | ||||
| for (size_t i = 0; i < group_size; i++) { | for (size_t i = 0; i < group_size; i++) { | ||||
| if (item[i] != group->at(i)) { | |||||
| if (item[i] != group.at(i)) { | |||||
| flag = false; | flag = false; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -1469,7 +1470,7 @@ bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) { | |||||
| void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) { | void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) { | ||||
| auto it = stream_relations_.find(start); | auto it = stream_relations_.find(start); | ||||
| if (it == stream_relations_.end()) { | if (it == stream_relations_.end()) { | ||||
| if (!IsVecExist(group)) { | |||||
| if (!IsVecExist(*group)) { | |||||
| stream_groups_.emplace_back(*group); | stream_groups_.emplace_back(*group); | ||||
| } else { | } else { | ||||
| MS_LOG(WARNING) << "DFS find same stream group, Not expected"; | MS_LOG(WARNING) << "DFS find same stream group, Not expected"; | ||||
| @@ -1781,7 +1782,6 @@ void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph | |||||
| MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId); | MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId); | ||||
| } | } | ||||
| } | } | ||||
| } // namespace ascend | } // namespace ascend | ||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -172,7 +172,7 @@ class AscendStreamAssign { | |||||
| // function for memory resue | // function for memory resue | ||||
| void GetStreamRelations(); | void GetStreamRelations(); | ||||
| void DFS(uint32_t start, std::vector<uint32_t> *group); | void DFS(uint32_t start, std::vector<uint32_t> *group); | ||||
| bool IsVecExist(std::vector<uint32_t> *group); | |||||
| bool IsVecExist(const std::vector<uint32_t> &group); | |||||
| void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr); | void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr); | ||||
| void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); | void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); | ||||
| void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index); | void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index); | ||||