Browse Source

!392 use GraphId as key in AscendKernelRuntime

Merge pull request !392 from caifubi/use-graphid-as-key-in-ascend-kernel-runtime
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
a51f01f206
2 changed files with 40 additions and 37 deletions
  1. +36
    -33
      mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
  2. +4
    -4
      mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h

+ 36
- 33
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc View File

@@ -54,9 +54,9 @@ static const size_t PRAMATER_OUTPUT_INDEX = 0;
AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); }


void AscendKernelRuntime::ClearGraphModelMap() { void AscendKernelRuntime::ClearGraphModelMap() {
for (auto &iter : graph_model_id_map_) {
MS_LOG(INFO) << "Ge UnloadModel " << iter.second;
auto ret = ge::model_runner::ModelRunner::Instance().UnloadModel(iter.second);
for (auto &iter : graph_model_map_) {
MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
auto ret = ge::model_runner::ModelRunner::Instance().UnloadModel(iter.first);
if (!ret) { if (!ret) {
MS_LOG(ERROR) << "UnloadModel failed"; MS_LOG(ERROR) << "UnloadModel failed";
} }
@@ -249,6 +249,10 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
} }


bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
if (graph == nullptr) {
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
}
MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->enable_task_sink(); bool is_task_sink = context_ptr->enable_task_sink();
@@ -261,19 +265,15 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph);
} }
#endif #endif
if (graph == nullptr) {
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
}
vector<std::shared_ptr<TaskInfo>> task_info_list; vector<std::shared_ptr<TaskInfo>> task_info_list;
auto anf_node_list = graph->execution_order(); auto anf_node_list = graph->execution_order();
TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id()); TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id());


// Store the task_info_list // Store the task_info_list
auto iter = task_map_.find(graph);
if (iter != task_map_.end()) {
MS_LOG(EXCEPTION) << "graph TaskInfo list already exist";
auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list));
if (!insert_ret.second) {
MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
} }
task_map_[graph] = task_info_list;


// Graph may have no compute node, such TensorAddGrad. // Graph may have no compute node, such TensorAddGrad.
if (task_info_list.empty()) { if (task_info_list.empty()) {
@@ -296,25 +296,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.GetTotalEventNum(), 0); 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.GetTotalEventNum(), 0);


graph_model_map_[graph] = model;
graph_model_id_map_[graph] = graph->graph_id();
auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
if (!ret.second) {
MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
}
MS_LOG(INFO) << "TaskGenerator GetTaskInfo end..."; MS_LOG(INFO) << "TaskGenerator GetTaskInfo end...";
return true; return true;
} }


uint32_t AscendKernelRuntime::GetGraphModelId(const session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto iter = graph_model_id_map_.find(kernel_graph);
if (iter == graph_model_id_map_.end()) {
MS_LOG(EXCEPTION) << "graph not in the map";
}
return iter->second;
}

bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
if (graph == nullptr) { if (graph == nullptr) {
MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
} }
MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->enable_task_sink(); bool is_task_sink = context_ptr->enable_task_sink();
@@ -327,23 +321,22 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
return true; return true;
} }


auto task_iter = graph_model_map_.find(graph);
if (task_iter == graph_model_map_.end()) {
MS_LOG(ERROR) << "task not exist";
auto model_iter = graph_model_map_.find(graph->graph_id());
if (model_iter == graph_model_map_.end()) {
MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph LoadTask without GenTask.";
return false; return false;
} }


auto model_id = GetGraphModelId(graph);
std::shared_ptr<ge::ModelListener> listener; std::shared_ptr<ge::ModelListener> listener;
MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_id;
bool status =
ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_id, task_iter->second, listener);
MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first;
bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first,
model_iter->second, listener);
if (!status) { if (!status) {
MS_LOG(INFO) << "load task failed";
MS_LOG(ERROR) << "load task failed";
return false; return false;
} }
if (ProfilingManager::GetInstance().IsProfiling()) { if (ProfilingManager::GetInstance().IsProfiling()) {
std::vector<uint32_t> task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_id);
std::vector<uint32_t> task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
ProfilingUtils::ReportProfilingData(graph->graph_id(), task_ids); ProfilingUtils::ReportProfilingData(graph->graph_id(), task_ids);
} }
return true; return true;
@@ -351,6 +344,8 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {


bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id();

auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
ge::InputData input_tensors = ge::InputData(); ge::InputData input_tensors = ge::InputData();
@@ -360,8 +355,12 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
return true; return true;
} }


auto model_id = GetGraphModelId(graph);
bool status = ge::model_runner::ModelRunner::Instance().RunModel(model_id, input_tensors, output_tensors);
if (!CheckGraphIdValid(graph->graph_id())) {
MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph RunTask without GenTask.";
return false;
}

bool status = ge::model_runner::ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors);
if (!status) { if (!status) {
MS_LOG(INFO) << "run task failed"; MS_LOG(INFO) << "run task failed";
return false; return false;
@@ -497,12 +496,16 @@ bool AscendKernelRuntime::DestroyHccl() {
} }


bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const { bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const {
auto iter = task_map_.find(graph);
auto iter = task_map_.find(graph->graph_id());
if (iter == task_map_.end()) { if (iter == task_map_.end()) {
MS_LOG(EXCEPTION) << "Unknown graph ptr"; MS_LOG(EXCEPTION) << "Unknown graph ptr";
} }
return iter->second.empty(); return iter->second.empty();
} }

bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const {
return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end();
}
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

+ 4
- 4
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h View File

@@ -23,6 +23,7 @@
#include "runtime/context.h" #include "runtime/context.h"
#include "framework/ge_runtime/davinci_model.h" #include "framework/ge_runtime/davinci_model.h"
#include "device/kernel_runtime_manager.h" #include "device/kernel_runtime_manager.h"
#include "session/session_basic.h"


using ge::model_runner::TaskInfo; using ge::model_runner::TaskInfo;
using std::unordered_map; using std::unordered_map;
@@ -54,14 +55,13 @@ class AscendKernelRuntime : public KernelRuntime {


void ClearGraphModelMap(); void ClearGraphModelMap();
void ReleaseDeviceRes() override; void ReleaseDeviceRes() override;
uint32_t GetGraphModelId(const session::KernelGraph *kernel_graph);
bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const; bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const;
bool CheckGraphIdValid(GraphId graph_id) const;


rtContext_t rt_context_{nullptr}; rtContext_t rt_context_{nullptr};
bool initialized_{false}; bool initialized_{false};
unordered_map<const session::KernelGraph *, vector<std::shared_ptr<TaskInfo>>> task_map_;
unordered_map<const session::KernelGraph *, std::shared_ptr<ge::model_runner::DavinciModel>> graph_model_map_;
unordered_map<const session::KernelGraph *, uint32_t> graph_model_id_map_;
unordered_map<GraphId, vector<std::shared_ptr<TaskInfo>>> task_map_;
unordered_map<GraphId, std::shared_ptr<ge::model_runner::DavinciModel>> graph_model_map_;
}; };


MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime); MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);


Loading…
Cancel
Save