diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index e5a7cc54b0..3998b6c9fe 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -14,10 +14,9 @@ * limitations under the License. */ #include "backend/session/executor.h" +#include "backend/session/executor_manager.h" #include #include - -#include "backend/session/executor_manager.h" #include "runtime/device/kernel_runtime_manager.h" #include "utils/comm_manager.h" #include "utils/scoped_long_running.h" @@ -120,14 +119,15 @@ void RunGraphTask::Run() { session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); UpdateOutputTensors(&outputs_, tensor_to_node_); } catch (const std::exception &e) { - MsException::GetInstance().SetException(); + ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); + MsException::Instance().SetException(); } graph->OnRunGraphFinished(); for (auto &tensor : input_need_lock_tensors_) { tensor->SetNeedWait(false); } NotifyOutputTensors(&outputs_); - ExecutorManager::Instance().OnRunGraphFinished(); + ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished); MS_LOG(INFO) << "End run graph " << graph_id_; } @@ -187,7 +187,8 @@ void Executor::WorkerLoop() { try { task->Run(); } catch (const std::exception &e) { - MsException::GetInstance().SetException(); + ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); + MsException::Instance().SetException(); } { std::unique_lock lock(task_mutex_); @@ -214,6 +215,22 @@ std::vector> Executor::GetNewReadyTasks() { return new_ready_tasks; } +// Handles an event broadcast by ExecutorManager::OnEvent. +void Executor::OnEvent(const ExecutorEvent &event) { + if (event == ExecutorEvent::kRunGraphFinished) { + OnRunGraphFinished(); + } else if (event == ExecutorEvent::kClear) { + // ExecutorManager::Clear() sends kClear before dropping its executors, so join the worker here. + WorkerJoin(); + } else if (event == ExecutorEvent::kException) { + std::unique_lock lock(task_mutex_); + while (!ready_tasks_.empty()) { + done_tasks_.emplace_back(ready_tasks_.front()); + ready_tasks_.pop(); + } + } +} + void Executor::OnRunGraphFinished() { 
auto new_ready_tasks = GetNewReadyTasks(); std::unique_lock lock(task_mutex_); @@ -249,7 +262,7 @@ void Executor::SyncRunTask(const std::shared_ptr &task) { done_tasks_.clear(); task_cond_var_.notify_all(); sync_cond_var_.wait(lock); - MsException::GetInstance().CheckException(); + MsException::Instance().CheckException(); } GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, @@ -311,7 +324,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, } } } - MsException::GetInstance().CheckException(); + MsException::Instance().CheckException(); for (auto &tensor : task->input_need_lock_tensors_) { tensor->SetNeedWait(true); } @@ -332,7 +345,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, mindspore::ScopedLongRunning long_running; std::unique_lock lock(reenter_mutex_); reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); }); - MsException::GetInstance().CheckException(); + MsException::Instance().CheckException(); } } diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h index 91932bfd6c..d0085f5263 100644 --- a/mindspore/ccsrc/backend/session/executor.h +++ b/mindspore/ccsrc/backend/session/executor.h @@ -26,7 +26,6 @@ #include #include #include - #include "backend/session/session_basic.h" #include "ir/anf.h" #include "ir/tensor.h" @@ -156,6 +155,8 @@ class ExitTask : public Task { ~ExitTask() override = default; }; +enum class ExecutorEvent { kClear, kRunGraphFinished, kException }; + class Executor { public: Executor(const std::string &device_name, uint32_t device_id); @@ -176,9 +177,9 @@ class Executor { VectorRef *outputs); void CleanUselessTensors(const SessionPtr &session, const std::shared_ptr> &useless_tensors); - void OnRunGraphFinished(); bool CreateCommGroup(const std::string &group_name, std::vector ranks); bool DestroyCommGroup(const std::string &group_name); + void OnEvent(const 
ExecutorEvent &event); private: void SyncRunTask(const std::shared_ptr &task); @@ -188,6 +189,7 @@ class Executor { bool IsTaskReady(const std::shared_ptr &task); void CheckException(); void OnWorkerExit(); + void OnRunGraphFinished(); uint32_t device_id_; std::string device_name_; diff --git a/mindspore/ccsrc/backend/session/executor_manager.cc b/mindspore/ccsrc/backend/session/executor_manager.cc index 3758adcf2e..46d34795bf 100644 --- a/mindspore/ccsrc/backend/session/executor_manager.cc +++ b/mindspore/ccsrc/backend/session/executor_manager.cc @@ -28,26 +28,17 @@ std::shared_ptr ExecutorManager::GetExecutor(const std::string &device return executor; } -void ExecutorManager::OnRunGraphFinished() { +void ExecutorManager::OnEvent(const ExecutorEvent &event) { for (auto &item : executors_) { auto &executor = item.second; if (executor != nullptr) { - executor->OnRunGraphFinished(); - } - } -} - -void ExecutorManager::JoinExecutorWorkers() { - for (auto &item : executors_) { - auto &executor = item.second; - if (executor != nullptr) { - executor->WorkerJoin(); + executor->OnEvent(event); } } } void ExecutorManager::Clear() { - JoinExecutorWorkers(); + OnEvent(ExecutorEvent::kClear); executors_.clear(); } } // namespace session diff --git a/mindspore/ccsrc/backend/session/executor_manager.h b/mindspore/ccsrc/backend/session/executor_manager.h index ff876b8673..9d1cbaa095 100644 --- a/mindspore/ccsrc/backend/session/executor_manager.h +++ b/mindspore/ccsrc/backend/session/executor_manager.h @@ -30,14 +30,14 @@ class ExecutorManager { return instance; } std::shared_ptr GetExecutor(const std::string &device_name, int device_id); - void OnRunGraphFinished(); + void OnEvent(const ExecutorEvent &event); void Clear(); private: ExecutorManager() = default; ~ExecutorManager() = default; DISABLE_COPY_AND_ASSIGN(ExecutorManager) - void JoinExecutorWorkers(); + std::map> executors_; }; } // namespace session diff --git a/mindspore/ccsrc/backend/session/session_basic.cc 
b/mindspore/ccsrc/backend/session/session_basic.cc index 2d284efe78..f3c763f4d7 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -637,7 +637,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, (*other_graph_cnode)[anf] = new_parameter; } continue; - } else if (optimize_control_depend) { + } else if (optimize_control_depend || IsPrimitiveCNode(anf, prim::kPrimControlDepend)) { cnode_inputs->push_back(NewValueNode(MakeValue(SizeToLong(input_idx)))); } else { // the input node is a cnode from other graph diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index b79d67cc9f..1a52f5014f 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -87,6 +87,10 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo if (node->isa() && !IsValueNode(node)) { eqv[node] = node; } else if (eqv.find(node) == eqv.end()) { + if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) { + eqv[node] = NewValueNode(MakeValue(0)); + return eqv[node]; + } bool ignore_make_tuple = false; if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { ignore_make_tuple = true; diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 8c2bad940e..b2eadd4d1a 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -87,10 +87,10 @@ class WaitEvent : public ExceptionListener { if (!need_wait_) { return; } - MsException::GetInstance().AddExceptionListener(const_cast(this)); + MsException::Instance().AddExceptionListener(const_cast(this)); cond_var_.wait(lock, [this] { return !need_wait_; }); - MsException::GetInstance().CheckException(); - MsException::GetInstance().RemoveExceptionListener(const_cast(this)); + MsException::Instance().CheckException(); + MsException::Instance().RemoveExceptionListener(const_cast(this)); } void set_need_wait(bool need_wait) { diff --git 
a/mindspore/core/utils/ms_exception.h b/mindspore/core/utils/ms_exception.h index fab3310a5b..8cc013fb47 100644 --- a/mindspore/core/utils/ms_exception.h +++ b/mindspore/core/utils/ms_exception.h @@ -27,7 +27,7 @@ class ExceptionListener { class MsException { public: - static MsException &GetInstance() { + static MsException &Instance() { static MsException instance; return instance; }