|
|
|
@@ -277,34 +277,39 @@ void Executor::ClearDoneTasks() { |
|
|
|
done_tasks_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync) { |
|
|
|
void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run) { |
|
|
|
{ |
|
|
|
std::lock_guard<std::mutex> lock(task_mutex_); |
|
|
|
ready_tasks_.push(task); |
|
|
|
} |
|
|
|
sync_run_task_finished_ = false; |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
ClearDoneTasks(); |
|
|
|
if (sync && !sync_run_task_finished_) { |
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_); |
|
|
|
sync_cond_var_.wait(lock, [this] { |
|
|
|
bool finished = sync_run_task_finished_; |
|
|
|
return finished; |
|
|
|
}); |
|
|
|
if (long_run) { |
|
|
|
mindspore::ScopedLongRunning long_running; |
|
|
|
sync_cond_var_.wait(lock, [this] { |
|
|
|
bool finished = sync_run_task_finished_; |
|
|
|
return finished; |
|
|
|
}); |
|
|
|
} else { |
|
|
|
sync_cond_var_.wait(lock, [this] { |
|
|
|
bool finished = sync_run_task_finished_; |
|
|
|
return finished; |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
ClearDoneTasks(); |
|
|
|
MsException::Instance().CheckException(); |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::SyncRunTask(const std::shared_ptr<Task> &task) { RunTask(task, true); } |
|
|
|
|
|
|
|
GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, |
|
|
|
const AnfNodePtrList &outputs) { |
|
|
|
auto task = std::make_shared<CompileNodesTask>(); |
|
|
|
task->session_ = session; |
|
|
|
task->segment_ = segment; |
|
|
|
task->output_nodes_ = outputs; |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true); |
|
|
|
return task->graph_id_; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -312,7 +317,7 @@ GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> |
|
|
|
auto task = std::make_shared<CompileGraphTask>(); |
|
|
|
task->session_ = session; |
|
|
|
task->func_graph_ = func_graph.get(); |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true); |
|
|
|
return task->graph_id_; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -320,7 +325,7 @@ void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) { |
|
|
|
auto task = std::make_shared<BuildGraphTask>(); |
|
|
|
task->session_ = session; |
|
|
|
task->graph_id_ = graphId; |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true); |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, |
|
|
|
@@ -334,8 +339,7 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, |
|
|
|
session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_); |
|
|
|
task->outputs_ = *outputs; |
|
|
|
task->sync_run_ = true; |
|
|
|
mindspore::ScopedLongRunning long_running; |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true, true); |
|
|
|
} |
|
|
|
|
|
|
|
void Executor::WaitTaskGraphAvailable(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) { |
|
|
|
@@ -350,7 +354,6 @@ void Executor::WaitTaskGraphAvailable(const SessionPtr &session, const std::shar |
|
|
|
} |
|
|
|
} |
|
|
|
if (need_lock) { |
|
|
|
ClearDoneTasks(); |
|
|
|
mindspore::ScopedLongRunning long_running; |
|
|
|
for (auto &tensor : task->input_tensors_) { |
|
|
|
if (tensor->NeedWait() && !tensor->IsGraphOutput()) { |
|
|
|
@@ -365,7 +368,6 @@ void Executor::WaitTaskGraphAvailable(const SessionPtr &session, const std::shar |
|
|
|
} |
|
|
|
auto graph = session->GetGraph(task->graph_id_); |
|
|
|
if (graph != nullptr && !graph->IsPostGraphFinished()) { |
|
|
|
ClearDoneTasks(); |
|
|
|
mindspore::ScopedLongRunning long_running; |
|
|
|
std::unique_lock<std::mutex> lock(reenter_mutex_); |
|
|
|
reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); }); |
|
|
|
@@ -388,8 +390,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, |
|
|
|
// sync run graph without output tensor(int dataset graph) |
|
|
|
if (!TensorInVector(outputs)) { |
|
|
|
task->sync_run_ = true; |
|
|
|
mindspore::ScopedLongRunning long_running; |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true, true); |
|
|
|
return; |
|
|
|
} |
|
|
|
WaitTaskGraphAvailable(session, task); |
|
|
|
@@ -415,8 +416,7 @@ void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const Gr |
|
|
|
tensor->Wait(); |
|
|
|
} |
|
|
|
} |
|
|
|
mindspore::ScopedLongRunning long_running; |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true, true); |
|
|
|
*outputs = task->outputs_; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -428,7 +428,7 @@ void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, |
|
|
|
task->session_ = session; |
|
|
|
task->graph_id_ = graph_id; |
|
|
|
task->input_tensors_ = inputs; |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true); |
|
|
|
*outputs = task->outputs_; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -436,14 +436,14 @@ bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32 |
|
|
|
auto task = std::make_shared<CreateCommGroupTask>(); |
|
|
|
task->group_name_ = group_name; |
|
|
|
task->ranks_ = ranks; |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true); |
|
|
|
return task->result_; |
|
|
|
} |
|
|
|
|
|
|
|
bool Executor::DestroyCommGroup(const std::string &group_name) { |
|
|
|
auto task = std::make_shared<DestroyCommGroupTask>(); |
|
|
|
task->group_name_ = group_name; |
|
|
|
SyncRunTask(task); |
|
|
|
RunTask(task, true); |
|
|
|
return task->result_; |
|
|
|
} |
|
|
|
|
|
|
|
|