|
|
|
@@ -260,6 +260,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { |
|
|
|
auto anf_node_list = graph->execution_order(); |
|
|
|
TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id()); |
|
|
|
|
|
|
|
// Store the task_info_list |
|
|
|
auto iter = task_map_.find(graph); |
|
|
|
if (iter != task_map_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "graph TaskInfo list already exist"; |
|
|
|
} |
|
|
|
task_map_[graph] = task_info_list; |
|
|
|
|
|
|
|
// Graph may have no compute node, such TensorAddGrad. |
|
|
|
if (task_info_list.empty()) { |
|
|
|
MS_LOG(WARNING) << "graph " << graph->graph_id() << " have no compute node"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); |
|
|
|
// the streams' flag not HEAD_STREAM |
|
|
|
std::vector<uint32_t> wait_active_stream_list = assign_instance.GetWaitStreams(); |
|
|
|
@@ -278,10 +291,6 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { |
|
|
|
graph_model_map_[graph] = model; |
|
|
|
graph_model_id_map_[graph] = graph->graph_id(); |
|
|
|
MS_LOG(INFO) << "TaskGenerator GetTaskInfo end..."; |
|
|
|
|
|
|
|
// Store the task_info_list |
|
|
|
task_map_.insert(std::make_pair(graph, task_info_list)); |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -305,6 +314,11 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
if (GraphWithEmptyTaskList(graph)) { |
|
|
|
MS_LOG(WARNING) << "LoadTask end, task list is empty"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
auto task_iter = graph_model_map_.find(graph); |
|
|
|
if (task_iter == graph_model_map_.end()) { |
|
|
|
MS_LOG(ERROR) << "task not exist"; |
|
|
|
@@ -333,6 +347,11 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
ge::InputData input_tensors = ge::InputData(); |
|
|
|
ge::OutputData *output_tensors = nullptr; |
|
|
|
if (GraphWithEmptyTaskList(graph)) { |
|
|
|
MS_LOG(WARNING) << "RunTask end, no task info found"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
auto model_id = GetGraphModelId(graph); |
|
|
|
bool status = ge::model_runner::ModelRunner::Instance().RunModel(model_id, input_tensors, output_tensors); |
|
|
|
if (!status) { |
|
|
|
@@ -468,6 +487,14 @@ bool AscendKernelRuntime::DestroyHccl() { |
|
|
|
context_ptr->set_enable_hccl(false); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const { |
|
|
|
auto iter = task_map_.find(graph); |
|
|
|
if (iter == task_map_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Unknown graph ptr"; |
|
|
|
} |
|
|
|
return iter->second.empty(); |
|
|
|
} |
|
|
|
} // namespace ascend |
|
|
|
} // namespace device |
|
|
|
} // namespace mindspore |