| @@ -14,10 +14,9 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/session/executor.h" | |||
| #include "backend/session/executor_manager.h" | |||
| #include <algorithm> | |||
| #include <exception> | |||
| #include "backend/session/executor_manager.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/scoped_long_running.h" | |||
| @@ -120,14 +119,15 @@ void RunGraphTask::Run() { | |||
| session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); | |||
| UpdateOutputTensors(&outputs_, tensor_to_node_); | |||
| } catch (const std::exception &e) { | |||
| MsException::GetInstance().SetException(); | |||
| ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); | |||
| MsException::Instance().SetException(); | |||
| } | |||
| graph->OnRunGraphFinished(); | |||
| for (auto &tensor : input_need_lock_tensors_) { | |||
| tensor->SetNeedWait(false); | |||
| } | |||
| NotifyOutputTensors(&outputs_); | |||
| ExecutorManager::Instance().OnRunGraphFinished(); | |||
| ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished); | |||
| MS_LOG(INFO) << "End run graph " << graph_id_; | |||
| } | |||
| @@ -187,7 +187,8 @@ void Executor::WorkerLoop() { | |||
| try { | |||
| task->Run(); | |||
| } catch (const std::exception &e) { | |||
| MsException::GetInstance().SetException(); | |||
| ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); | |||
| MsException::Instance().SetException(); | |||
| } | |||
| { | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| @@ -214,6 +215,18 @@ std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() { | |||
| return new_ready_tasks; | |||
| } | |||
| void Executor::OnEvent(const ExecutorEvent &event) { | |||
| if (event == ExecutorEvent::kRunGraphFinished) { | |||
| OnRunGraphFinished(); | |||
| } else if (event == ExecutorEvent::kException) { | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| while (!ready_tasks_.empty()) { | |||
| done_tasks_.emplace_back(ready_tasks_.front()); | |||
| ready_tasks_.pop(); | |||
| } | |||
| } | |||
| } | |||
| void Executor::OnRunGraphFinished() { | |||
| auto new_ready_tasks = GetNewReadyTasks(); | |||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||
| @@ -249,7 +262,7 @@ void Executor::SyncRunTask(const std::shared_ptr<Task> &task) { | |||
| done_tasks_.clear(); | |||
| task_cond_var_.notify_all(); | |||
| sync_cond_var_.wait(lock); | |||
| MsException::GetInstance().CheckException(); | |||
| MsException::Instance().CheckException(); | |||
| } | |||
| GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, | |||
| @@ -311,7 +324,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||
| } | |||
| } | |||
| } | |||
| MsException::GetInstance().CheckException(); | |||
| MsException::Instance().CheckException(); | |||
| for (auto &tensor : task->input_need_lock_tensors_) { | |||
| tensor->SetNeedWait(true); | |||
| } | |||
| @@ -332,7 +345,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||
| mindspore::ScopedLongRunning long_running; | |||
| std::unique_lock<std::mutex> lock(reenter_mutex_); | |||
| reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); }); | |||
| MsException::GetInstance().CheckException(); | |||
| MsException::Instance().CheckException(); | |||
| } | |||
| } | |||
| @@ -26,7 +26,6 @@ | |||
| #include <thread> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "backend/session/session_basic.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/tensor.h" | |||
| @@ -156,6 +155,8 @@ class ExitTask : public Task { | |||
| ~ExitTask() override = default; | |||
| }; | |||
| enum class ExecutorEvent { kClear, kRunGraphFinished, kException }; | |||
| class Executor { | |||
| public: | |||
| Executor(const std::string &device_name, uint32_t device_id); | |||
| @@ -176,9 +177,9 @@ class Executor { | |||
| VectorRef *outputs); | |||
| void CleanUselessTensors(const SessionPtr &session, | |||
| const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors); | |||
| void OnRunGraphFinished(); | |||
| bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks); | |||
| bool DestroyCommGroup(const std::string &group_name); | |||
| void OnEvent(const ExecutorEvent &event); | |||
| private: | |||
| void SyncRunTask(const std::shared_ptr<Task> &task); | |||
| @@ -188,6 +189,7 @@ class Executor { | |||
| bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task); | |||
| void CheckException(); | |||
| void OnWorkerExit(); | |||
| void OnRunGraphFinished(); | |||
| uint32_t device_id_; | |||
| std::string device_name_; | |||
| @@ -28,26 +28,17 @@ std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device | |||
| return executor; | |||
| } | |||
| void ExecutorManager::OnRunGraphFinished() { | |||
| void ExecutorManager::OnEvent(const ExecutorEvent &event) { | |||
| for (auto &item : executors_) { | |||
| auto &executor = item.second; | |||
| if (executor != nullptr) { | |||
| executor->OnRunGraphFinished(); | |||
| } | |||
| } | |||
| } | |||
| void ExecutorManager::JoinExecutorWorkers() { | |||
| for (auto &item : executors_) { | |||
| auto &executor = item.second; | |||
| if (executor != nullptr) { | |||
| executor->WorkerJoin(); | |||
| executor->OnEvent(event); | |||
| } | |||
| } | |||
| } | |||
| void ExecutorManager::Clear() { | |||
| JoinExecutorWorkers(); | |||
| OnEvent(ExecutorEvent::kClear); | |||
| executors_.clear(); | |||
| } | |||
| } // namespace session | |||
| @@ -30,14 +30,14 @@ class ExecutorManager { | |||
| return instance; | |||
| } | |||
| std::shared_ptr<Executor> GetExecutor(const std::string &device_name, int device_id); | |||
| void OnRunGraphFinished(); | |||
| void OnEvent(const ExecutorEvent &event); | |||
| void Clear(); | |||
| private: | |||
| ExecutorManager() = default; | |||
| ~ExecutorManager() = default; | |||
| DISABLE_COPY_AND_ASSIGN(ExecutorManager) | |||
| void JoinExecutorWorkers(); | |||
| std::map<std::string, std::shared_ptr<Executor>> executors_; | |||
| }; | |||
| } // namespace session | |||
| @@ -637,7 +637,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, | |||
| (*other_graph_cnode)[anf] = new_parameter; | |||
| } | |||
| continue; | |||
| } else if (optimize_control_depend) { | |||
| } else if (optimize_control_depend || IsPrimitiveCNode(anf, prim::kPrimControlDepend)) { | |||
| cnode_inputs->push_back(NewValueNode(MakeValue(SizeToLong(input_idx)))); | |||
| } else { | |||
| // the input node is a cnode from other graph | |||
| @@ -87,6 +87,10 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo | |||
| if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { | |||
| eqv[node] = node; | |||
| } else if (eqv.find(node) == eqv.end()) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) { | |||
| eqv[node] = NewValueNode(MakeValue(0)); | |||
| return eqv[node]; | |||
| } | |||
| bool ignore_make_tuple = false; | |||
| if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | |||
| ignore_make_tuple = true; | |||
| @@ -87,10 +87,10 @@ class WaitEvent : public ExceptionListener { | |||
| if (!need_wait_) { | |||
| return; | |||
| } | |||
| MsException::GetInstance().AddExceptionListener(const_cast<WaitEvent *>(this)); | |||
| MsException::Instance().AddExceptionListener(const_cast<WaitEvent *>(this)); | |||
| cond_var_.wait(lock, [this] { return !need_wait_; }); | |||
| MsException::GetInstance().CheckException(); | |||
| MsException::GetInstance().RemoveExceptionListener(const_cast<WaitEvent *>(this)); | |||
| MsException::Instance().CheckException(); | |||
| MsException::Instance().RemoveExceptionListener(const_cast<WaitEvent *>(this)); | |||
| } | |||
| void set_need_wait(bool need_wait) { | |||
| @@ -27,7 +27,7 @@ class ExceptionListener { | |||
| class MsException { | |||
| public: | |||
| static MsException &GetInstance() { | |||
| static MsException &Instance() { | |||
| static MsException instance; | |||
| return instance; | |||
| } | |||