Browse Source

!6255 fix reviewbot

Merge pull request !6255 from gukecai/codex
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
16f321d1d9
4 changed files with 17 additions and 18 deletions
  1. +6
    -6
      mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.cc
  2. +4
    -5
      mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.h
  3. +6
    -6
      mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
  4. +1
    -1
      mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h

+ 6
- 6
mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.cc View File

@@ -35,7 +35,7 @@ Status CallbackManager::Init() {
} }


Status CallbackManager::CallbackProcess() { Status CallbackManager::CallbackProcess() {
std::pair<rtEvent_t, std::pair<rtCallback_t, void *>> entry;
std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>> entry;
while (true) { while (true) {
if (!callback_queue_.Pop(&entry)) { if (!callback_queue_.Pop(&entry)) {
MS_LOG(INFO) << "CallbackManager stopped"; MS_LOG(INFO) << "CallbackManager stopped";
@@ -84,7 +84,7 @@ Status CallbackManager::Destroy() {
return ret; return ret;
} }


Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) {
Status CallbackManager::RegisterCallback(rtCallback_t callback, const void *user_data) {
MS_LOG(INFO) << "To register callback"; MS_LOG(INFO) << "To register callback";
rtEvent_t event = nullptr; rtEvent_t event = nullptr;
auto ret = rtEventCreate(&event); auto ret = rtEventCreate(&event);
@@ -98,8 +98,8 @@ Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data)
MS_LOG(ERROR) << "Record event failed"; MS_LOG(ERROR) << "Record event failed";
return kFail; return kFail;
} }
auto cb = std::pair<rtCallback_t, void *>(callback, user_data);
auto entry = std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>(event, std::move(cb));
auto cb = std::pair<rtCallback_t, const void *>(callback, user_data);
auto entry = std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>>(event, std::move(cb));
if (!callback_queue_.Push(entry)) { if (!callback_queue_.Push(entry)) {
return kFail; return kFail;
} }
@@ -108,9 +108,9 @@ Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data)
return kSuccess; return kSuccess;
} }


void CallbackManager::RtCallbackFunc(void *data) {
void CallbackManager::RtCallbackFunc(const void *data) {
MS_LOG(INFO) << "To invoke callback function"; MS_LOG(INFO) << "To invoke callback function";
auto callback_func = reinterpret_cast<std::function<void()> *>(data);
auto callback_func = reinterpret_cast<const std::function<void()> *>(data);
(*callback_func)(); (*callback_func)();
delete callback_func; delete callback_func;
} }


+ 4
- 5
mindspore/ccsrc/profiler/device/ascend/rt_callback_manager.h View File

@@ -24,11 +24,10 @@
#include <utility> #include <utility>
#include "profiler/device/ascend/blocking_queue.h" #include "profiler/device/ascend/blocking_queue.h"
#include "runtime/base.h" #include "runtime/base.h"

namespace mindspore { namespace mindspore {
namespace profiler { namespace profiler {
namespace ascend { namespace ascend {
using rtCallback_t = std::function<void(void *)>;
using rtCallback_t = std::function<void(const void *)>;
enum Status { kSuccess = 0, kFail, kInvalidParam }; enum Status { kSuccess = 0, kFail, kInvalidParam };
class CallbackManager { class CallbackManager {
public: public:
@@ -45,14 +44,14 @@ class CallbackManager {


Status Destroy(); Status Destroy();


Status RegisterCallback(rtCallback_t callback, void *user_data);
Status RegisterCallback(rtCallback_t callback, const void *user_data);
Status RegisterCallback(const std::function<void()> &callback); Status RegisterCallback(const std::function<void()> &callback);


private: private:
Status CallbackProcess(); Status CallbackProcess();
static void RtCallbackFunc(void *data);
static void RtCallbackFunc(const void *data);


BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>> callback_queue_;
BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, const void *>>> callback_queue_;
rtStream_t stream_; rtStream_t stream_;
std::future<Status> ret_future_; std::future<Status> ret_future_;
}; };


+ 6
- 6
mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc View File

@@ -38,6 +38,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
Reset(); Reset();
SetLoopSink(); SetLoopSink();
ReorderIndependentOrders(graph_ptr); ReorderIndependentOrders(graph_ptr);

AssignAllNodesStream(graph_ptr); AssignAllNodesStream(graph_ptr);
UpdateAtomicAddrCleanStreamId(graph_ptr); UpdateAtomicAddrCleanStreamId(graph_ptr);
InsertStreamActive(graph_ptr); InsertStreamActive(graph_ptr);
@@ -1438,19 +1439,19 @@ void AscendStreamAssign::Reset() {
} }


// section 10 // section 10
bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) {
auto group_size = group->size();
bool AscendStreamAssign::IsVecExist(const std::vector<uint32_t> &group) {
auto group_size = group.size();
if (group_size == 0) { if (group_size == 0) {
return false; return false;
} }
for (const auto &item : stream_groups_) { for (const auto &item : stream_groups_) {
if (item.size() < group->size()) {
if (item.size() < group.size()) {
continue; continue;
} }


bool flag = true; bool flag = true;
for (size_t i = 0; i < group_size; i++) { for (size_t i = 0; i < group_size; i++) {
if (item[i] != group->at(i)) {
if (item[i] != group.at(i)) {
flag = false; flag = false;
break; break;
} }
@@ -1469,7 +1470,7 @@ bool AscendStreamAssign::IsVecExist(std::vector<uint32_t> *group) {
void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) { void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) {
auto it = stream_relations_.find(start); auto it = stream_relations_.find(start);
if (it == stream_relations_.end()) { if (it == stream_relations_.end()) {
if (!IsVecExist(group)) {
if (!IsVecExist(*group)) {
stream_groups_.emplace_back(*group); stream_groups_.emplace_back(*group);
} else { } else {
MS_LOG(WARNING) << "DFS find same stream group, Not expected"; MS_LOG(WARNING) << "DFS find same stream group, Not expected";
@@ -1781,7 +1782,6 @@ void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph
MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId); MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
} }
} }

} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

+ 1
- 1
mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h View File

@@ -172,7 +172,7 @@ class AscendStreamAssign {
// function for memory resue // function for memory resue
void GetStreamRelations(); void GetStreamRelations();
void DFS(uint32_t start, std::vector<uint32_t> *group); void DFS(uint32_t start, std::vector<uint32_t> *group);
bool IsVecExist(std::vector<uint32_t> *group);
bool IsVecExist(const std::vector<uint32_t> &group);
void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr); void FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr);
void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr);
void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index); void GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index);


Loading…
Cancel
Save