From 9a4757984295cf3c561fdf08140578601ba2e05c Mon Sep 17 00:00:00 2001
From: kswang
Date: Tue, 22 Dec 2020 19:45:10 +0800
Subject: [PATCH] fix thread pool performance problem

---
 mindspore/ccsrc/backend/session/executor.cc   | 21 +++--
 .../ccsrc/backend/session/executor_manager.cc |  3 +-
 mindspore/ccsrc/common/thread_pool.cc         | 85 ++++++++++++++++++-
 mindspore/ccsrc/common/thread_pool.h          | 12 ++-
 4 files changed, 108 insertions(+), 13 deletions(-)

diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc
index 3998b6c9fe..2b80bd5c81 100644
--- a/mindspore/ccsrc/backend/session/executor.cc
+++ b/mindspore/ccsrc/backend/session/executor.cc
@@ -122,13 +122,13 @@ void RunGraphTask::Run() {
     ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
     MsException::Instance().SetException();
   }
+  MS_LOG(INFO) << "End run graph " << graph_id_;
   graph->OnRunGraphFinished();
   for (auto &tensor : input_need_lock_tensors_) {
     tensor->SetNeedWait(false);
   }
   NotifyOutputTensors(&outputs_);
   ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
-  MS_LOG(INFO) << "End run graph " << graph_id_;
 }
 
 void RunOpTask::Run() {
@@ -218,11 +218,22 @@ std::vector<std::shared_ptr<Task>> Executor::GetNewReadyTasks() {
 void Executor::OnEvent(const ExecutorEvent &event) {
   if (event == ExecutorEvent::kRunGraphFinished) {
     OnRunGraphFinished();
+  } else if (event == ExecutorEvent::kClear) {
+    WorkerJoin();
   } else if (event == ExecutorEvent::kException) {
-    std::unique_lock<std::mutex> lock(task_mutex_);
-    while (!ready_tasks_.empty()) {
-      done_tasks_.emplace_back(ready_tasks_.front());
-      ready_tasks_.pop();
+    {
+      std::unique_lock<std::mutex> lock(task_mutex_);
+      while (!ready_tasks_.empty()) {
+        done_tasks_.emplace_back(ready_tasks_.front());
+        ready_tasks_.pop();
+      }
+    }
+    {
+      std::unique_lock<std::mutex> lock(pending_task_mutex_);
+      for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end(); ++iter) {
+        done_tasks_.emplace_back(*iter);
+      }
+      pending_tasks_.clear();
     }
   }
 }
diff --git a/mindspore/ccsrc/backend/session/executor_manager.cc b/mindspore/ccsrc/backend/session/executor_manager.cc
index 46d34795bf..d4e01a2bdd 100644
--- a/mindspore/ccsrc/backend/session/executor_manager.cc
+++ b/mindspore/ccsrc/backend/session/executor_manager.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "backend/session/executor_manager.h"
-
+#include "common/thread_pool.h"
 namespace mindspore {
 namespace session {
 std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device_name, int device_id) {
@@ -40,6 +40,7 @@ void ExecutorManager::OnEvent(const ExecutorEvent &event) {
 void ExecutorManager::Clear() {
   OnEvent(ExecutorEvent::kClear);
   executors_.clear();
+  common::ThreadPool::GetInstance().ClearThreadPool();
 }
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/common/thread_pool.cc b/mindspore/ccsrc/common/thread_pool.cc
index d6b3101019..23cd483ab1 100644
--- a/mindspore/ccsrc/common/thread_pool.cc
+++ b/mindspore/ccsrc/common/thread_pool.cc
@@ -26,7 +26,7 @@ namespace common {
 #ifdef ENABLE_D
 const int kDeviceNum = 8;
 #endif
-
+const int kMaxThreadNum = 23;
 bool Queue::Enqueue(Task *task) {
   const int tail_index = tail_.load(std::memory_order_relaxed);
   // queue full
@@ -64,11 +64,15 @@ ThreadPool::ThreadPool() {
 #else
   max_thread_num_ = process_core_num;
 #endif
-  SetThreadPool(core_thread_num_);
+  if (max_thread_num_ < 1) {
+    max_thread_num_ = 1;
+  }
+  if (max_thread_num_ > kMaxThreadNum) {
+    max_thread_num_ = kMaxThreadNum;
+  }
 }
 
 bool ThreadPool::SetThreadPool(int config_thread_num) {
-  std::lock_guard<std::mutex> Lock(pool_mtx_);
   if (config_thread_num > max_thread_num_) {
     MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num="
                       << config_thread_num << ", allowed max thread num=" << max_thread_num_;
@@ -142,7 +146,65 @@ void ThreadPool::SubRunThread(int num) {
   cur_thread_run_nums_ = num;
 }
 
+void ThreadPool::SyncRunLoop() {
+  while (true) {
+    Task task;
+    {
+      std::unique_lock<std::mutex> lock(task_mutex_);
+      task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
+      if (exit_run_) {
+        return;
+      }
+      task = task_queue_.front();
+      task_queue_.pop();
+    }
+    try {
+      task();
+    } catch (std::exception &e) {
+      MsException::Instance().SetException();
+    }
+    {
+      std::unique_lock<std::mutex> task_lock(task_mutex_);
+      task_finished_count_ = task_finished_count_ + 1;
+    }
+    finished_cond_var_.notify_one();
+  }
+}
+
 bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
+  if (tasks.size() == 1) {
+    auto ret = tasks[0]();
+    return ret == SUCCESS;
+  }
+  std::unique_lock<std::mutex> lock(pool_mtx_);
+  exit_run_ = false;
+  int task_num = tasks.size();
+  int thread_num = sync_run_threads_.size();
+  if (thread_num < max_thread_num_ && thread_num < task_num) {
+    auto new_thread_num = max_thread_num_;
+    if (task_num < max_thread_num_) {
+      new_thread_num = task_num;
+    }
+    for (int i = thread_num; i < new_thread_num; ++i) {
+      sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
+    }
+  }
+
+  for (auto &task : tasks) {
+    std::lock_guard<std::mutex> task_lock(task_mutex_);
+    task_queue_.push(task);
+    task_cond_var_.notify_one();
+  }
+  {
+    std::unique_lock<std::mutex> task_lock(task_mutex_);
+    finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
+    task_finished_count_ = 0;
+  }
+  return true;
+}
+
+bool ThreadPool::InnerSyncRun(const std::vector<Task> &tasks) {
+  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
   int thread_num = tasks.size();
   if (thread_num > max_thread_num_) {
     thread_num = max_thread_num_;
@@ -196,19 +258,34 @@ ThreadPool &ThreadPool::GetInstance() {
   return instance;
 }
 
-ThreadPool::~ThreadPool() {
+void ThreadPool::ClearThreadPool() {
+  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
+  if (exit_run_) {
+    return;
+  }
   exit_run_ = true;
   cur_thread_run_nums_ = static_cast<int>(thread_list_.size());
   SubRunThread(0);
   queue_ready_.notify_all();
+  task_cond_var_.notify_all();
+  for (auto &it : sync_run_threads_) {
+    if (it.joinable()) {
+      it.join();
+    }
+  }
+  sync_run_threads_.clear();
   for (auto &it : thread_list_) {
     if (it.joinable()) {
       it.join();
     }
   }
+  thread_list_.clear();
   for (const auto &it : activate_list_) {
     delete it;
   }
+  activate_list_.clear();
 }
+
+ThreadPool::~ThreadPool() { ClearThreadPool(); }
 }  // namespace common
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/common/thread_pool.h b/mindspore/ccsrc/common/thread_pool.h
index 3ea3ca43f6..14e931fd06 100644
--- a/mindspore/ccsrc/common/thread_pool.h
+++ b/mindspore/ccsrc/common/thread_pool.h
@@ -56,12 +56,10 @@ class ThreadPool {
   ~ThreadPool();
   ThreadPool(const ThreadPool &) = delete;
   ThreadPool &operator=(const ThreadPool &) = delete;
-
   static ThreadPool &GetInstance();
-
   // Use the tasks' size of threads to execute these tasks, one thread execute one task.
   bool SyncRun(const std::vector<Task> &tasks);
-  size_t GetSyncRunThreadNum() { return max_thread_num_; }
+  void ClearThreadPool();
 
  private:
   ThreadPool();
@@ -70,6 +68,8 @@ class ThreadPool {
   void AddRunThread(int num);
   void SubRunThread(int num);
   bool CheckResult();
+  bool InnerSyncRun(const std::vector<Task> &tasks);
+  void SyncRunLoop();
 
   int cur_thread_nums_{0};
   int cur_thread_run_nums_{0};
@@ -83,6 +83,12 @@ class ThreadPool {
   std::vector<std::thread> thread_list_{};
   std::vector<std::shared_ptr<Queue>> queue_list_{};
   std::vector>> error_info_{};
+  std::queue<Task> task_queue_;
+  std::mutex task_mutex_;
+  std::condition_variable task_cond_var_;
+  int task_finished_count_{0};
+  std::condition_variable finished_cond_var_;
+  std::vector<std::thread> sync_run_threads_{};
 };
 }  // namespace common
 }  // namespace mindspore
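
A minimal usage sketch (not part of the patch) may help illustrate the change: SyncRun now pushes work onto a shared task_queue_ consumed by persistent SyncRunLoop worker threads, created at most once and capped at max_thread_num_, instead of standing threads up on every call, which is where the performance win comes from. The sketch assumes Task is a std::function-style callable whose return value is comparable to SUCCESS and that both names live in mindspore::common alongside ThreadPool, as the tasks[0]() early-return path in SyncRun suggests; LaunchParallelFill and its fixed 4-way split are hypothetical.

// Hypothetical caller of the reworked ThreadPool::SyncRun.
#include <algorithm>
#include <vector>
#include "common/thread_pool.h"

void LaunchParallelFill(std::vector<float> *data) {
  std::vector<mindspore::common::Task> tasks;
  constexpr size_t kSplit = 4;  // illustrative split; SyncRun itself caps workers at max_thread_num_
  size_t chunk = (data->size() + kSplit - 1) / kSplit;
  for (size_t t = 0; t < kSplit; ++t) {
    size_t begin = t * chunk;
    size_t end = std::min(begin + chunk, data->size());
    // Each lambda becomes one queued Task; SyncRunLoop workers pop and run it,
    // and SyncRun blocks on finished_cond_var_ until every task has reported done.
    tasks.emplace_back([data, begin, end]() {
      for (size_t i = begin; i < end; ++i) {
        (*data)[i] = static_cast<float>(i);
      }
      return mindspore::common::SUCCESS;
    });
  }
  (void)mindspore::common::ThreadPool::GetInstance().SyncRun(tasks);
}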