| @@ -14,10 +14,9 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "backend/session/executor.h" | #include "backend/session/executor.h" | ||||
| #include "backend/session/executor_manager.h" | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <exception> | #include <exception> | ||||
| #include "backend/session/executor_manager.h" | |||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/scoped_long_running.h" | #include "utils/scoped_long_running.h" | ||||
| @@ -120,14 +119,15 @@ void RunGraphTask::Run() { | |||||
| session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); | session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); | ||||
| UpdateOutputTensors(&outputs_, tensor_to_node_); | UpdateOutputTensors(&outputs_, tensor_to_node_); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| MsException::GetInstance().SetException(); | |||||
| ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); | |||||
| MsException::Instance().SetException(); | |||||
| } | } | ||||
| graph->OnRunGraphFinished(); | graph->OnRunGraphFinished(); | ||||
| for (auto &tensor : input_need_lock_tensors_) { | for (auto &tensor : input_need_lock_tensors_) { | ||||
| tensor->SetNeedWait(false); | tensor->SetNeedWait(false); | ||||
| } | } | ||||
| NotifyOutputTensors(&outputs_); | NotifyOutputTensors(&outputs_); | ||||
| ExecutorManager::Instance().OnRunGraphFinished(); | |||||
| ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished); | |||||
| MS_LOG(INFO) << "End run graph " << graph_id_; | MS_LOG(INFO) << "End run graph " << graph_id_; | ||||
| } | } | ||||
| @@ -187,7 +187,8 @@ void Executor::WorkerLoop() { | |||||
| try { | try { | ||||
| task->Run(); | task->Run(); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| MsException::GetInstance().SetException(); | |||||
| ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); | |||||
| MsException::Instance().SetException(); | |||||
| } | } | ||||
| { | { | ||||
| std::unique_lock<std::mutex> lock(task_mutex_); | std::unique_lock<std::mutex> lock(task_mutex_); | ||||
| @@ -214,6 +215,18 @@ std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() { | |||||
| return new_ready_tasks; | return new_ready_tasks; | ||||
| } | } | ||||
| void Executor::OnEvent(const ExecutorEvent &event) { | |||||
| if (event == ExecutorEvent::kRunGraphFinished) { | |||||
| OnRunGraphFinished(); | |||||
| } else if (event == ExecutorEvent::kException) { | |||||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||||
| while (!ready_tasks_.empty()) { | |||||
| done_tasks_.emplace_back(ready_tasks_.front()); | |||||
| ready_tasks_.pop(); | |||||
| } | |||||
| } | |||||
| } | |||||
| void Executor::OnRunGraphFinished() { | void Executor::OnRunGraphFinished() { | ||||
| auto new_ready_tasks = GetNewReadyTasks(); | auto new_ready_tasks = GetNewReadyTasks(); | ||||
| std::unique_lock<std::mutex> lock(task_mutex_); | std::unique_lock<std::mutex> lock(task_mutex_); | ||||
| @@ -249,7 +262,7 @@ void Executor::SyncRunTask(const std::shared_ptr<Task> &task) { | |||||
| done_tasks_.clear(); | done_tasks_.clear(); | ||||
| task_cond_var_.notify_all(); | task_cond_var_.notify_all(); | ||||
| sync_cond_var_.wait(lock); | sync_cond_var_.wait(lock); | ||||
| MsException::GetInstance().CheckException(); | |||||
| MsException::Instance().CheckException(); | |||||
| } | } | ||||
| GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, | GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, | ||||
| @@ -311,7 +324,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| MsException::GetInstance().CheckException(); | |||||
| MsException::Instance().CheckException(); | |||||
| for (auto &tensor : task->input_need_lock_tensors_) { | for (auto &tensor : task->input_need_lock_tensors_) { | ||||
| tensor->SetNeedWait(true); | tensor->SetNeedWait(true); | ||||
| } | } | ||||
| @@ -332,7 +345,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||||
| mindspore::ScopedLongRunning long_running; | mindspore::ScopedLongRunning long_running; | ||||
| std::unique_lock<std::mutex> lock(reenter_mutex_); | std::unique_lock<std::mutex> lock(reenter_mutex_); | ||||
| reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); }); | reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); }); | ||||
| MsException::GetInstance().CheckException(); | |||||
| MsException::Instance().CheckException(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -26,7 +26,6 @@ | |||||
| #include <thread> | #include <thread> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "backend/session/session_basic.h" | #include "backend/session/session_basic.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| @@ -156,6 +155,8 @@ class ExitTask : public Task { | |||||
| ~ExitTask() override = default; | ~ExitTask() override = default; | ||||
| }; | }; | ||||
// Events that ExecutorManager::OnEvent() forwards to every Executor.
enum class ExecutorEvent {
  kClear,             // sent by ExecutorManager::Clear() before the executors are dropped
  kRunGraphFinished,  // sent at the end of RunGraphTask::Run()
  kException          // sent when a std::exception is caught while running a task
};
| class Executor { | class Executor { | ||||
| public: | public: | ||||
| Executor(const std::string &device_name, uint32_t device_id); | Executor(const std::string &device_name, uint32_t device_id); | ||||
| @@ -176,9 +177,9 @@ class Executor { | |||||
| VectorRef *outputs); | VectorRef *outputs); | ||||
| void CleanUselessTensors(const SessionPtr &session, | void CleanUselessTensors(const SessionPtr &session, | ||||
| const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors); | const std::shared_ptr<std::vector<tensor::TensorPtr>> &useless_tensors); | ||||
| void OnRunGraphFinished(); | |||||
| bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks); | bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks); | ||||
| bool DestroyCommGroup(const std::string &group_name); | bool DestroyCommGroup(const std::string &group_name); | ||||
| void OnEvent(const ExecutorEvent &event); | |||||
| private: | private: | ||||
| void SyncRunTask(const std::shared_ptr<Task> &task); | void SyncRunTask(const std::shared_ptr<Task> &task); | ||||
| @@ -188,6 +189,7 @@ class Executor { | |||||
| bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task); | bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task); | ||||
| void CheckException(); | void CheckException(); | ||||
| void OnWorkerExit(); | void OnWorkerExit(); | ||||
| void OnRunGraphFinished(); | |||||
| uint32_t device_id_; | uint32_t device_id_; | ||||
| std::string device_name_; | std::string device_name_; | ||||
| @@ -28,26 +28,17 @@ std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device | |||||
| return executor; | return executor; | ||||
| } | } | ||||
| void ExecutorManager::OnRunGraphFinished() { | |||||
| void ExecutorManager::OnEvent(const ExecutorEvent &event) { | |||||
| for (auto &item : executors_) { | for (auto &item : executors_) { | ||||
| auto &executor = item.second; | auto &executor = item.second; | ||||
| if (executor != nullptr) { | if (executor != nullptr) { | ||||
| executor->OnRunGraphFinished(); | |||||
| } | |||||
| } | |||||
| } | |||||
| void ExecutorManager::JoinExecutorWorkers() { | |||||
| for (auto &item : executors_) { | |||||
| auto &executor = item.second; | |||||
| if (executor != nullptr) { | |||||
| executor->WorkerJoin(); | |||||
| executor->OnEvent(event); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void ExecutorManager::Clear() { | void ExecutorManager::Clear() { | ||||
| JoinExecutorWorkers(); | |||||
| OnEvent(ExecutorEvent::kClear); | |||||
| executors_.clear(); | executors_.clear(); | ||||
| } | } | ||||
| } // namespace session | } // namespace session | ||||
| @@ -30,14 +30,14 @@ class ExecutorManager { | |||||
| return instance; | return instance; | ||||
| } | } | ||||
| std::shared_ptr<Executor> GetExecutor(const std::string &device_name, int device_id); | std::shared_ptr<Executor> GetExecutor(const std::string &device_name, int device_id); | ||||
| void OnRunGraphFinished(); | |||||
| void OnEvent(const ExecutorEvent &event); | |||||
| void Clear(); | void Clear(); | ||||
| private: | private: | ||||
| ExecutorManager() = default; | ExecutorManager() = default; | ||||
| ~ExecutorManager() = default; | ~ExecutorManager() = default; | ||||
| DISABLE_COPY_AND_ASSIGN(ExecutorManager) | DISABLE_COPY_AND_ASSIGN(ExecutorManager) | ||||
| void JoinExecutorWorkers(); | |||||
| std::map<std::string, std::shared_ptr<Executor>> executors_; | std::map<std::string, std::shared_ptr<Executor>> executors_; | ||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| @@ -637,7 +637,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, | |||||
| (*other_graph_cnode)[anf] = new_parameter; | (*other_graph_cnode)[anf] = new_parameter; | ||||
| } | } | ||||
| continue; | continue; | ||||
| } else if (optimize_control_depend) { | |||||
| } else if (optimize_control_depend || IsPrimitiveCNode(anf, prim::kPrimControlDepend)) { | |||||
| cnode_inputs->push_back(NewValueNode(MakeValue(SizeToLong(input_idx)))); | cnode_inputs->push_back(NewValueNode(MakeValue(SizeToLong(input_idx)))); | ||||
| } else { | } else { | ||||
| // the input node is a cnode from other graph | // the input node is a cnode from other graph | ||||
| @@ -87,6 +87,10 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo | |||||
| if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { | if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { | ||||
| eqv[node] = node; | eqv[node] = node; | ||||
| } else if (eqv.find(node) == eqv.end()) { | } else if (eqv.find(node) == eqv.end()) { | ||||
| if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) { | |||||
| eqv[node] = NewValueNode(MakeValue(0)); | |||||
| return eqv[node]; | |||||
| } | |||||
| bool ignore_make_tuple = false; | bool ignore_make_tuple = false; | ||||
| if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | ||||
| ignore_make_tuple = true; | ignore_make_tuple = true; | ||||
| @@ -87,10 +87,10 @@ class WaitEvent : public ExceptionListener { | |||||
| if (!need_wait_) { | if (!need_wait_) { | ||||
| return; | return; | ||||
| } | } | ||||
| MsException::GetInstance().AddExceptionListener(const_cast<WaitEvent *>(this)); | |||||
| MsException::Instance().AddExceptionListener(const_cast<WaitEvent *>(this)); | |||||
| cond_var_.wait(lock, [this] { return !need_wait_; }); | cond_var_.wait(lock, [this] { return !need_wait_; }); | ||||
| MsException::GetInstance().CheckException(); | |||||
| MsException::GetInstance().RemoveExceptionListener(const_cast<WaitEvent *>(this)); | |||||
| MsException::Instance().CheckException(); | |||||
| MsException::Instance().RemoveExceptionListener(const_cast<WaitEvent *>(this)); | |||||
| } | } | ||||
| void set_need_wait(bool need_wait) { | void set_need_wait(bool need_wait) { | ||||
| @@ -27,7 +27,7 @@ class ExceptionListener { | |||||
| class MsException { | class MsException { | ||||
| public: | public: | ||||
| static MsException &GetInstance() { | |||||
| static MsException &Instance() { | |||||
| static MsException instance; | static MsException instance; | ||||
| return instance; | return instance; | ||||
| } | } | ||||