diff --git a/mindspore/ccsrc/common/thread_pool.cc b/mindspore/ccsrc/common/thread_pool.cc index 23cd483ab1..62cb3708fa 100644 --- a/mindspore/ccsrc/common/thread_pool.cc +++ b/mindspore/ccsrc/common/thread_pool.cc @@ -24,38 +24,12 @@ namespace mindspore { namespace common { #ifdef ENABLE_D -const int kDeviceNum = 8; +const size_t kDeviceNum = 8; #endif -const int kMaxThreadNum = 23; -bool Queue::Enqueue(Task *task) { - const int tail_index = tail_.load(std::memory_order_relaxed); - // queue full - auto next = (tail_index + 1) % 2; - if (next == head_.load(std::memory_order_acquire)) { - return false; - } - buffer_[tail_index] = task; - tail_.store(next, std::memory_order_release); - ++task_size_; - return true; -} - -bool Queue::Dequeue(Task **out) { - if (task_size_ == 0) { - return false; - } - // queue empty - const int head_index = head_.load(std::memory_order_relaxed); - if (head_index == tail_.load(std::memory_order_acquire)) { - return false; - } - *out = buffer_[head_index]; - head_.store((head_index + 1) % 2, std::memory_order_release); - return true; -} +const size_t kMaxThreadNum = 23; ThreadPool::ThreadPool() { - int process_core_num = std::thread::hardware_concurrency() - 1; + size_t process_core_num = std::thread::hardware_concurrency() - 1; if (process_core_num < 1) { process_core_num = 1; } @@ -72,80 +46,6 @@ ThreadPool::ThreadPool() { } } -bool ThreadPool::SetThreadPool(int config_thread_num) { - if (config_thread_num > max_thread_num_) { - MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num=" - << config_thread_num << ", allowed max thread num=" << max_thread_num_; - } - if (config_thread_num > cur_thread_nums_) { - AddNewThread(config_thread_num - cur_thread_nums_); - } - MS_LOG(DEBUG) << "cur_thread_nums_=" << cur_thread_nums_ << ", cur_thread_run_nums_=" << cur_thread_run_nums_; - return true; -} - -void ThreadPool::AddNewThread(int add_num) { - for (int i = cur_thread_nums_, j = 0; j < add_num; ++i, ++j) { - auto active = new std::atomic_bool{true}; - auto queue = std::make_shared(); - std::thread thread([this, i, active, queue]() { - Task *task = nullptr; - while (!exit_run_) { - while (*active) { - if (queue->Dequeue(&task)) { - int ret; - try { - ret = (*task)(); - } catch (std::exception &e) { - ret = FAIL; - MsException::Instance().SetException(); - } - if (ret != SUCCESS) { - error_info_.emplace_back(std::make_pair(i, std::make_pair(false, ret))); - } - queue->task_size_--; - } - std::this_thread::yield(); - } - std::unique_lock queue_lock(thread_mtx_); - queue_ready_.wait(queue_lock, [active, this] { return exit_run_ || *active; }); - } - }); - thread_list_.emplace_back(std::move(thread)); - activate_list_.emplace_back(active); - queue_list_.emplace_back(queue); - } - cur_thread_nums_ += add_num; - cur_thread_run_nums_ += add_num; - MS_LOG(INFO) << "add " << add_num << " thread"; -} - -void ThreadPool::AddRunThread(int num) { - MS_LOG(DEBUG) << "num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_; - int active_nums = num - cur_thread_run_nums_; - if (active_nums <= 0 || static_cast(activate_list_.size()) < active_nums) { - return; - } - for (int i = cur_thread_run_nums_ - 1, j = 0; j < active_nums; ++i, ++j) { - *activate_list_[i] = true; - } - std::lock_guard queueLock(thread_mtx_); - queue_ready_.notify_all(); - cur_thread_run_nums_ = num; -} - -void ThreadPool::SubRunThread(int num) { - MS_LOG(DEBUG) << "sub num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_; - int deactive_nums = cur_thread_run_nums_ - num; - if (deactive_nums <= 0) { - return; - } - for (int i = num, j = 0; j < deactive_nums; ++i, ++j) { - *activate_list_[i] = false; - } - cur_thread_run_nums_ = num; -} - void ThreadPool::SyncRunLoop() { while (true) { Task task; @@ -178,14 +78,14 @@ bool ThreadPool::SyncRun(const std::vector &tasks) { } std::unique_lock lock(pool_mtx_); exit_run_ = false; - int task_num = tasks.size(); - int thread_num = sync_run_threads_.size(); + size_t task_num = tasks.size(); + size_t thread_num = sync_run_threads_.size(); if (thread_num < max_thread_num_ && thread_num < task_num) { auto new_thread_num = max_thread_num_; if (task_num < max_thread_num_) { new_thread_num = task_num; } - for (int i = thread_num; i < new_thread_num; ++i) { + for (size_t i = thread_num; i < new_thread_num; ++i) { sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this)); } } @@ -203,56 +103,6 @@ bool ThreadPool::SyncRun(const std::vector &tasks) { return true; } -bool ThreadPool::InnerSyncRun(const std::vector &tasks) { - std::lock_guard sync_run_lock(pool_mtx_); - int thread_num = tasks.size(); - if (thread_num > max_thread_num_) { - thread_num = max_thread_num_; - } - if (!SetThreadPool(thread_num)) { - return false; - } - error_info_.clear(); - bool succ_flag; - for (int task_id = 0, queue_index = 0; task_id < SizeToInt(tasks.size()); ++task_id) { - do { - succ_flag = true; - if (!queue_list_[queue_index]->Enqueue(const_cast(&tasks[task_id]))) { - std::this_thread::yield(); - succ_flag = false; - } - } while (!succ_flag); - queue_index++; - if (queue_index >= cur_thread_run_nums_) { - queue_index = queue_index - cur_thread_run_nums_; - } - } - succ_flag = false; - while (!succ_flag) { - std::this_thread::yield(); - succ_flag = true; - for (int i = 0; i < cur_thread_run_nums_; ++i) { - if (queue_list_[i]->task_size_ != 0) { - succ_flag = false; - break; - } - } - } - MS_LOG(INFO) << "Finish " << tasks.size() << " task successful"; - return CheckResult(); -} - -bool ThreadPool::CheckResult() { - bool succ_flag = true; - for (auto result : error_info_) { - if (result.second.first) { - MS_LOG(ERROR) << "task " << result.first << " failed, error code is " << result.second.second; - succ_flag = false; - } - } - return succ_flag; -} - ThreadPool &ThreadPool::GetInstance() { static ThreadPool instance; return instance; @@ -264,9 +114,6 @@ void ThreadPool::ClearThreadPool() { return; } exit_run_ = true; - cur_thread_run_nums_ = static_cast(thread_list_.size()); - SubRunThread(0); - queue_ready_.notify_all(); task_cond_var_.notify_all(); for (auto &it : sync_run_threads_) { if (it.joinable()) { @@ -274,16 +121,6 @@ void ThreadPool::ClearThreadPool() { } } sync_run_threads_.clear(); - for (auto &it : thread_list_) { - if (it.joinable()) { - it.join(); - } - } - thread_list_.clear(); - for (const auto &it : activate_list_) { - delete it; - } - activate_list_.clear(); } ThreadPool::~ThreadPool() { ClearThreadPool(); } diff --git a/mindspore/ccsrc/common/thread_pool.h b/mindspore/ccsrc/common/thread_pool.h index 14e931fd06..6a6aa71791 100644 --- a/mindspore/ccsrc/common/thread_pool.h +++ b/mindspore/ccsrc/common/thread_pool.h @@ -32,25 +32,9 @@ namespace mindspore { namespace common { -const int kCoreThreadNum = 3; -const int kDefaultMaxThreadNum = 8; enum Status { FAIL = -1, SUCCESS = 0 }; using Task = std::function; -class Queue { - public: - Queue() = default; - ~Queue() = default; - bool Enqueue(Task *task); - bool Dequeue(Task **out); - std::atomic_int task_size_ = {0}; - - private: - std::atomic_int head_ = {0}; - std::atomic_int tail_ = {0}; - Task *buffer_[2]{}; -}; - class ThreadPool { public: ~ThreadPool(); @@ -63,30 +47,15 @@ class ThreadPool { private: ThreadPool(); - bool SetThreadPool(int config_thread_num); - void AddNewThread(int add_num); - void AddRunThread(int num); - void SubRunThread(int num); - bool CheckResult(); - bool InnerSyncRun(const std::vector &tasks); void SyncRunLoop(); - int cur_thread_nums_{0}; - int cur_thread_run_nums_{0}; - int core_thread_num_{kCoreThreadNum}; - int max_thread_num_{kDefaultMaxThreadNum}; + size_t max_thread_num_{1}; std::mutex pool_mtx_; - std::mutex thread_mtx_; - std::condition_variable queue_ready_; std::atomic_bool exit_run_ = {false}; - std::vector activate_list_{}; - std::vector thread_list_{}; - std::vector> queue_list_{}; - std::vector>> error_info_{}; std::queue task_queue_; std::mutex task_mutex_; std::condition_variable task_cond_var_; - int task_finished_count_{0}; + size_t task_finished_count_{0}; std::condition_variable finished_cond_var_; std::vector sync_run_threads_{}; };