@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #include "backend/session/executor.h"
+#include <exception>
 #include "runtime/device/kernel_runtime_manager.h"
 #include "backend/session/executor_manager.h"
 #include "utils/comm_manager.h"
@@ -40,10 +41,7 @@ void UpdateOutputTensors(const VectorRef *outputs,
         tensor->set_device_address(address);
       }
       if (tensor->NeedSyncDeviceToHostImmediately()) {
-        auto tensor_address = tensor->device_address();
-        MS_EXCEPTION_IF_NULL(tensor_address);
-        tensor_address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                         tensor->data_c());
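+        // Sync the data back to host through the tensor itself, replacing the direct SyncDeviceToHost call removed above.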
+        tensor->data_sync(false);
         tensor->set_device_address(nullptr);
         tensor->set_sync_status(kNeedSyncHostToDevice);
       }
@@ -85,7 +83,11 @@ void BuildGraphTask::Run() {
 
 void RunGraphTask::Run() {
   MS_EXCEPTION_IF_NULL(session_);
-  session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
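+  // Run the graph and, if it throws, park the exception in MsException so the waiting caller can rethrow it.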
+  try {
+    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
+  } catch (const std::exception &e) {
+    MsException::GetInstance().SetException();
+  }
   UpdateOutputTensors(&outputs_, tensor_to_node_);
   for (auto &tensor : input_need_lock_tensors_) {
     tensor->SetNeedWait(false);
@@ -115,14 +117,6 @@ Executor::Executor(const std::string &device_name, uint32_t device_id) {
 
 Executor::~Executor() { WorkerJoin(); }
 
-void Executor::CheckException() {
-  if (exception_ptr_ != nullptr) {
-    auto exception_ptr = exception_ptr_;
-    exception_ptr_ = nullptr;
-    std::rethrow_exception(exception_ptr);
-  }
-}
-
 void Executor::WorkerJoin() {
   // Avoid worker thread join itself which will cause deadlock
   if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
@@ -152,7 +146,7 @@ void Executor::WorkerLoop() {
     try {
       task->Run();
     } catch (const std::exception &e) {
-      exception_ptr_ = std::current_exception();
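+      // Record the exception in the MsException singleton instead of the executor's own exception_ptr_ member.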
+      MsException::GetInstance().SetException();
     }
     if (task->type_ != kRunGraph || task->sync_run_) {
       task = nullptr;
@@ -200,48 +194,40 @@ bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
   return true;
 }
 
-GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
-  CheckException();
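+// Enqueue a task, wake the worker thread, block until it finishes, then rethrow any exception the worker recorded.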
+void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
   std::unique_lock<std::mutex> lock(task_mutex_);
+  ready_tasks_.push(task);
+  task_cond_var_.notify_all();
+  sync_cond_var_.wait(lock);
+  MsException::GetInstance().CheckException();
+}
+
+GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   auto task = std::make_shared<CompileNodesTask>();
   task->session_ = session;
   task->nodes_ = lst;
   task->output_nodes_ = outputs;
-  ready_tasks_.push(task);
-  task_cond_var_.notify_all();
-  sync_cond_var_.wait(lock);
-  CheckException();
+  SyncRunTask(task);
   return task->graph_id_;
 }
 
 GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
-  CheckException();
-  std::unique_lock<std::mutex> lock(task_mutex_);
   auto task = std::make_shared<CompileGraphTask>();
   task->session_ = session;
   task->func_graph_ = func_graph;
-  ready_tasks_.push(task);
-  task_cond_var_.notify_all();
-  sync_cond_var_.wait(lock);
-  CheckException();
+  SyncRunTask(task);
   return task->graph_id_;
 }
 
 void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
-  CheckException();
-  std::unique_lock<std::mutex> lock(task_mutex_);
   auto task = std::make_shared<BuildGraphTask>();
   task->session_ = session;
   task->graph_id_ = graphId;
-  ready_tasks_.push(task);
-  task_cond_var_.notify_all();
-  sync_cond_var_.wait(lock);
-  CheckException();
+  SyncRunTask(task);
 }
 
 void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
                         const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
-  CheckException();
   MS_EXCEPTION_IF_NULL(session);
   MS_EXCEPTION_IF_NULL(outputs);
   auto task = std::make_shared<RunGraphTask>();
@@ -251,30 +237,25 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
   session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
   task->outputs_ = *outputs;
   task->sync_run_ = true;
-  std::unique_lock<std::mutex> lock(task_mutex_);
-  ready_tasks_.push(task);
-  task_cond_var_.notify_all();
   mindspore::ScopedLongRunning long_running;
-  sync_cond_var_.wait(lock);
-  CheckException();
+  SyncRunTask(task);
 }
 
 void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                              const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
-  CheckException();
   MS_EXCEPTION_IF_NULL(session);
   MS_EXCEPTION_IF_NULL(outputs);
   auto task = std::make_shared<RunGraphTask>();
   task->session_ = session;
   task->graph_id_ = graph_id;
   task->input_tensors_ = inputs;
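+  // These inputs are flagged with NeedWait below and released again in RunGraphTask::Run.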
+  task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs);
   // lock inputs
   for (auto &tensor : inputs) {
     if (tensor->NeedWait()) {
       task->input_need_wait_tensors_.emplace_back(tensor);
     }
   }
-  task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs);
   for (auto &tensor : task->input_need_lock_tensors_) {
     tensor->SetNeedWait(true);
   }
@@ -285,12 +266,8 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
   // sync run graph without output tensor(int dataset graph)
   if (!TensorInVector(outputs)) {
     task->sync_run_ = true;
-    std::unique_lock<std::mutex> lock(task_mutex_);
-    ready_tasks_.push(task);
-    task_cond_var_.notify_all();
     mindspore::ScopedLongRunning long_running;
-    sync_cond_var_.wait(lock);
-    CheckException();
+    SyncRunTask(task);
     return;
   }
 
@@ -307,54 +284,38 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
 
 void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                        const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
-  CheckException();
-  std::unique_lock<std::mutex> lock(task_mutex_);
   auto task = std::make_shared<BuildOpTask>();
   task->session_ = session;
   task->op_run_info_ = op_run_info;
   task->graph_info_ = graph_info;
   task->input_tensors_ = input_tensors;
   task->tensors_mask_ = tensors_mask;
-  ready_tasks_.push(task);
-  task_cond_var_.notify_all();
-  sync_cond_var_.wait(lock);
-  CheckException();
+  SyncRunTask(task);
 }
 
 void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                      const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
-  CheckException();
-  std::unique_lock<std::mutex> lock(task_mutex_);
   auto task = std::make_shared<RunOpTask>();
   task->session_ = session;
   task->op_run_info_ = op_run_info;
   task->graph_info_ = graph_info;
   task->input_tensors_ = input_tensors;
-  ready_tasks_.push(task);
-  task_cond_var_.notify_all();
-  sync_cond_var_.wait(lock);
-  CheckException();
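+  // outputs_ is filled in by the worker before SyncRunTask returns, so it can be copied out below.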
+  SyncRunTask(task);
   *outputs = task->outputs_;
 }
 
 bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
-  std::unique_lock<std::mutex> lock(task_mutex_);
   auto task = std::make_shared<CreateCommGroupTask>();
   task->group_name_ = group_name;
   task->ranks_ = ranks;
-  ready_tasks_.push(task);
-  task_cond_var_.notify_all();
-  sync_cond_var_.wait(lock);
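+  // SyncRunTask blocks until the worker has run the task, so result_ is valid when read below.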
+  SyncRunTask(task);
   return task->result_;
 }
 
 bool Executor::DestroyCommGroup(const std::string &group_name) {
-  std::unique_lock<std::mutex> lock(task_mutex_);
   auto task = std::make_shared<DestroyCommGroupTask>();
   task->group_name_ = group_name;
-  ready_tasks_.push(task);
-  task_cond_var_.notify_all();
-  sync_cond_var_.wait(lock);
+  SyncRunTask(task);
   return task->result_;
 }
 