From: @kisnwang Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -122,13 +122,13 @@ void RunGraphTask::Run() { | |||||
| ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); | ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); | ||||
| MsException::Instance().SetException(); | MsException::Instance().SetException(); | ||||
| } | } | ||||
| MS_LOG(INFO) << "End run graph " << graph_id_; | |||||
| graph->OnRunGraphFinished(); | graph->OnRunGraphFinished(); | ||||
| for (auto &tensor : input_need_lock_tensors_) { | for (auto &tensor : input_need_lock_tensors_) { | ||||
| tensor->SetNeedWait(false); | tensor->SetNeedWait(false); | ||||
| } | } | ||||
| NotifyOutputTensors(&outputs_); | NotifyOutputTensors(&outputs_); | ||||
| ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished); | ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished); | ||||
| MS_LOG(INFO) << "End run graph " << graph_id_; | |||||
| } | } | ||||
| void RunOpTask::Run() { | void RunOpTask::Run() { | ||||
| @@ -213,11 +213,22 @@ std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() { | |||||
| void Executor::OnEvent(const ExecutorEvent &event) { | void Executor::OnEvent(const ExecutorEvent &event) { | ||||
| if (event == ExecutorEvent::kRunGraphFinished) { | if (event == ExecutorEvent::kRunGraphFinished) { | ||||
| OnRunGraphFinished(); | OnRunGraphFinished(); | ||||
| } else if (event == ExecutorEvent::kClear) { | |||||
| WorkerJoin(); | |||||
| } else if (event == ExecutorEvent::kException) { | } else if (event == ExecutorEvent::kException) { | ||||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||||
| while (!ready_tasks_.empty()) { | |||||
| done_tasks_.emplace_back(ready_tasks_.front()); | |||||
| ready_tasks_.pop(); | |||||
| { | |||||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||||
| while (!ready_tasks_.empty()) { | |||||
| done_tasks_.emplace_back(ready_tasks_.front()); | |||||
| ready_tasks_.pop(); | |||||
| } | |||||
| } | |||||
| { | |||||
| std::unique_lock<std::mutex> lock(pending_task_mutex_); | |||||
| for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) { | |||||
| done_tasks_.emplace_back(*iter); | |||||
| } | |||||
| pending_tasks_.clear(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "backend/session/executor_manager.h" | #include "backend/session/executor_manager.h" | ||||
| #include "common/thread_pool.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device_name, int device_id) { | std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device_name, int device_id) { | ||||
| @@ -40,6 +40,7 @@ void ExecutorManager::OnEvent(const ExecutorEvent &event) { | |||||
| void ExecutorManager::Clear() { | void ExecutorManager::Clear() { | ||||
| OnEvent(ExecutorEvent::kClear); | OnEvent(ExecutorEvent::kClear); | ||||
| executors_.clear(); | executors_.clear(); | ||||
| common::ThreadPool::GetInstance().ClearThreadPool(); | |||||
| } | } | ||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,7 +26,7 @@ namespace common { | |||||
| #ifdef ENABLE_D | #ifdef ENABLE_D | ||||
| const int kDeviceNum = 8; | const int kDeviceNum = 8; | ||||
| #endif | #endif | ||||
| const int kMaxThreadNum = 23; | |||||
| bool Queue::Enqueue(Task *task) { | bool Queue::Enqueue(Task *task) { | ||||
| const int tail_index = tail_.load(std::memory_order_relaxed); | const int tail_index = tail_.load(std::memory_order_relaxed); | ||||
| // queue full | // queue full | ||||
| @@ -64,11 +64,15 @@ ThreadPool::ThreadPool() { | |||||
| #else | #else | ||||
| max_thread_num_ = process_core_num; | max_thread_num_ = process_core_num; | ||||
| #endif | #endif | ||||
| SetThreadPool(core_thread_num_); | |||||
| if (max_thread_num_ < 1) { | |||||
| max_thread_num_ = 1; | |||||
| } | |||||
| if (max_thread_num_ > kMaxThreadNum) { | |||||
| max_thread_num_ = kMaxThreadNum; | |||||
| } | |||||
| } | } | ||||
| bool ThreadPool::SetThreadPool(int config_thread_num) { | bool ThreadPool::SetThreadPool(int config_thread_num) { | ||||
| std::lock_guard<std::mutex> Lock(pool_mtx_); | |||||
| if (config_thread_num > max_thread_num_) { | if (config_thread_num > max_thread_num_) { | ||||
| MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num=" | MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num=" | ||||
| << config_thread_num << ", allowed max thread num=" << max_thread_num_; | << config_thread_num << ", allowed max thread num=" << max_thread_num_; | ||||
| @@ -142,7 +146,65 @@ void ThreadPool::SubRunThread(int num) { | |||||
| cur_thread_run_nums_ = num; | cur_thread_run_nums_ = num; | ||||
| } | } | ||||
| void ThreadPool::SyncRunLoop() { | |||||
| while (true) { | |||||
| Task task; | |||||
| { | |||||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||||
| task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; }); | |||||
| if (exit_run_) { | |||||
| return; | |||||
| } | |||||
| task = task_queue_.front(); | |||||
| task_queue_.pop(); | |||||
| } | |||||
| try { | |||||
| task(); | |||||
| } catch (std::exception &e) { | |||||
| MsException::Instance().SetException(); | |||||
| } | |||||
| { | |||||
| std::unique_lock<std::mutex> task_lock(task_mutex_); | |||||
| task_finished_count_ = task_finished_count_ + 1; | |||||
| } | |||||
| finished_cond_var_.notify_one(); | |||||
| } | |||||
| } | |||||
| bool ThreadPool::SyncRun(const std::vector<Task> &tasks) { | bool ThreadPool::SyncRun(const std::vector<Task> &tasks) { | ||||
| if (tasks.size() == 1) { | |||||
| auto ret = tasks[0](); | |||||
| return ret == SUCCESS; | |||||
| } | |||||
| std::unique_lock<std::mutex> lock(pool_mtx_); | |||||
| exit_run_ = false; | |||||
| int task_num = tasks.size(); | |||||
| int thread_num = sync_run_threads_.size(); | |||||
| if (thread_num < max_thread_num_ && thread_num < task_num) { | |||||
| auto new_thread_num = max_thread_num_; | |||||
| if (task_num < max_thread_num_) { | |||||
| new_thread_num = task_num; | |||||
| } | |||||
| for (int i = thread_num; i < new_thread_num; ++i) { | |||||
| sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this)); | |||||
| } | |||||
| } | |||||
| for (auto &task : tasks) { | |||||
| std::lock_guard<std::mutex> task_lock(task_mutex_); | |||||
| task_queue_.push(task); | |||||
| task_cond_var_.notify_one(); | |||||
| } | |||||
| { | |||||
| std::unique_lock<std::mutex> task_lock(task_mutex_); | |||||
| finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; }); | |||||
| task_finished_count_ = 0; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ThreadPool::InnerSyncRun(const std::vector<Task> &tasks) { | |||||
| std::lock_guard<std::mutex> sync_run_lock(pool_mtx_); | |||||
| int thread_num = tasks.size(); | int thread_num = tasks.size(); | ||||
| if (thread_num > max_thread_num_) { | if (thread_num > max_thread_num_) { | ||||
| thread_num = max_thread_num_; | thread_num = max_thread_num_; | ||||
| @@ -196,19 +258,34 @@ ThreadPool &ThreadPool::GetInstance() { | |||||
| return instance; | return instance; | ||||
| } | } | ||||
| ThreadPool::~ThreadPool() { | |||||
| void ThreadPool::ClearThreadPool() { | |||||
| std::lock_guard<std::mutex> sync_run_lock(pool_mtx_); | |||||
| if (exit_run_) { | |||||
| return; | |||||
| } | |||||
| exit_run_ = true; | exit_run_ = true; | ||||
| cur_thread_run_nums_ = static_cast<int>(thread_list_.size()); | cur_thread_run_nums_ = static_cast<int>(thread_list_.size()); | ||||
| SubRunThread(0); | SubRunThread(0); | ||||
| queue_ready_.notify_all(); | queue_ready_.notify_all(); | ||||
| task_cond_var_.notify_all(); | |||||
| for (auto &it : sync_run_threads_) { | |||||
| if (it.joinable()) { | |||||
| it.join(); | |||||
| } | |||||
| } | |||||
| sync_run_threads_.clear(); | |||||
| for (auto &it : thread_list_) { | for (auto &it : thread_list_) { | ||||
| if (it.joinable()) { | if (it.joinable()) { | ||||
| it.join(); | it.join(); | ||||
| } | } | ||||
| } | } | ||||
| thread_list_.clear(); | |||||
| for (const auto &it : activate_list_) { | for (const auto &it : activate_list_) { | ||||
| delete it; | delete it; | ||||
| } | } | ||||
| activate_list_.clear(); | |||||
| } | } | ||||
| ThreadPool::~ThreadPool() { ClearThreadPool(); } | |||||
| } // namespace common | } // namespace common | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -56,12 +56,10 @@ class ThreadPool { | |||||
| ~ThreadPool(); | ~ThreadPool(); | ||||
| ThreadPool(const ThreadPool &) = delete; | ThreadPool(const ThreadPool &) = delete; | ||||
| ThreadPool &operator=(const ThreadPool &) = delete; | ThreadPool &operator=(const ThreadPool &) = delete; | ||||
| static ThreadPool &GetInstance(); | static ThreadPool &GetInstance(); | ||||
| // Use the tasks' size of threads to execute these tasks, one thread execute one task. | |||||
| bool SyncRun(const std::vector<Task> &tasks); | bool SyncRun(const std::vector<Task> &tasks); | ||||
| size_t GetSyncRunThreadNum() { return max_thread_num_; } | size_t GetSyncRunThreadNum() { return max_thread_num_; } | ||||
| void ClearThreadPool(); | |||||
| private: | private: | ||||
| ThreadPool(); | ThreadPool(); | ||||
| @@ -70,6 +68,8 @@ class ThreadPool { | |||||
| void AddRunThread(int num); | void AddRunThread(int num); | ||||
| void SubRunThread(int num); | void SubRunThread(int num); | ||||
| bool CheckResult(); | bool CheckResult(); | ||||
| bool InnerSyncRun(const std::vector<Task> &tasks); | |||||
| void SyncRunLoop(); | |||||
| int cur_thread_nums_{0}; | int cur_thread_nums_{0}; | ||||
| int cur_thread_run_nums_{0}; | int cur_thread_run_nums_{0}; | ||||
| @@ -83,6 +83,12 @@ class ThreadPool { | |||||
| std::vector<std::thread> thread_list_{}; | std::vector<std::thread> thread_list_{}; | ||||
| std::vector<std::shared_ptr<Queue>> queue_list_{}; | std::vector<std::shared_ptr<Queue>> queue_list_{}; | ||||
| std::vector<std::pair<int, std::pair<bool, int>>> error_info_{}; | std::vector<std::pair<int, std::pair<bool, int>>> error_info_{}; | ||||
| std::queue<Task> task_queue_; | |||||
| std::mutex task_mutex_; | |||||
| std::condition_variable task_cond_var_; | |||||
| int task_finished_count_{0}; | |||||
| std::condition_variable finished_cond_var_; | |||||
| std::vector<std::thread> sync_run_threads_{}; | |||||
| }; | }; | ||||
| } // namespace common | } // namespace common | ||||
| } // namespace mindspore | } // namespace mindspore | ||||