|
|
|
@@ -110,6 +110,12 @@ Executor::Executor(const std::string &device_name, uint32_t device_id) { |
|
|
|
worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this); |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::CheckException() { |
|
|
|
if (exception_ptr_ != nullptr) { |
|
|
|
std::rethrow_exception(exception_ptr_); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::WorkerJoin() { |
|
|
|
StopWorker(); |
|
|
|
worker_->join(); |
|
|
|
@@ -128,7 +134,11 @@ void Executor::WorkerLoop() { |
|
|
|
OnWorkerExit(); |
|
|
|
return; |
|
|
|
} |
|
|
|
task->Run(); |
|
|
|
try { |
|
|
|
task->Run(); |
|
|
|
} catch (const std::exception &e) { |
|
|
|
exception_ptr_ = std::current_exception(); |
|
|
|
} |
|
|
|
if (task->type_ == kCompileNodes) { |
|
|
|
compile_cond_var_.notify_all(); |
|
|
|
} else if (task->type_ == kCompileGraph) { |
|
|
|
@@ -183,6 +193,7 @@ bool Executor::IsAllInputsReady(const std::vector<tensor::TensorPtr> &inputs) { |
|
|
|
|
|
|
|
GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrList &lst, |
|
|
|
const AnfNodePtrList &outputs) { |
|
|
|
CheckException(); |
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_); |
|
|
|
auto task = std::make_shared<CompileNodesTask>(); |
|
|
|
task->session_ = session; |
|
|
|
@@ -191,10 +202,12 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL |
|
|
|
ready_tasks_.push(task); |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
compile_cond_var_.wait(lock); |
|
|
|
CheckException(); |
|
|
|
return task->graph_id_; |
|
|
|
} |
|
|
|
|
|
|
|
GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) { |
|
|
|
CheckException(); |
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_); |
|
|
|
auto task = std::make_shared<CompileGraphTask>(); |
|
|
|
task->session_ = session; |
|
|
|
@@ -202,10 +215,12 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph |
|
|
|
ready_tasks_.push(task); |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
compile_cond_var_.wait(lock); |
|
|
|
CheckException(); |
|
|
|
return task->graph_id_; |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) { |
|
|
|
CheckException(); |
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_); |
|
|
|
auto task = std::make_shared<BuildGraphTask>(); |
|
|
|
task->session_ = session; |
|
|
|
@@ -213,10 +228,12 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) { |
|
|
|
ready_tasks_.push(task); |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
build_cond_var_.wait(lock); |
|
|
|
CheckException(); |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, |
|
|
|
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) { |
|
|
|
CheckException(); |
|
|
|
auto task = std::make_shared<RunGraphTask>(); |
|
|
|
task->session_ = session; |
|
|
|
task->graph_id_ = graph_id; |
|
|
|
@@ -237,10 +254,12 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
py::gil_scoped_release release; |
|
|
|
run_cond_var_.wait(lock); |
|
|
|
CheckException(); |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, |
|
|
|
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) { |
|
|
|
CheckException(); |
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_); |
|
|
|
auto task = std::make_shared<BuildOpTask>(); |
|
|
|
task->session_ = session; |
|
|
|
@@ -251,10 +270,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c |
|
|
|
ready_tasks_.push(task); |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
build_op_cond_var_.wait(lock); |
|
|
|
CheckException(); |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, |
|
|
|
const std::vector<tensor::TensorPtr> &input_tensors) { |
|
|
|
CheckException(); |
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_); |
|
|
|
auto task = std::make_shared<RunOpTask>(); |
|
|
|
task->session_ = session; |
|
|
|
@@ -264,6 +285,7 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info |
|
|
|
ready_tasks_.push(task); |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
run_op_cond_var_.wait(lock); |
|
|
|
CheckException(); |
|
|
|
|
|
|
|
// Trans output to tuple |
|
|
|
auto output_tensors = TransformBaseRefListToTuple(task->outputs_); |
|
|
|
|