/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/thread_pool.h"
#include <algorithm>
#include <atomic>
#include <exception>
#include <mutex>
#include <thread>
#include <vector>
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_exception.h"

namespace mindspore {
namespace common {
#ifdef ENABLE_D
// Divisor applied to the usable core count when ENABLE_D is defined.
// NOTE(review): presumably the number of devices sharing this host's cores — confirm.
constexpr size_t kDeviceNum = 8;
#endif
// Hard upper bound on the number of worker threads SyncRun will ever start.
constexpr size_t kMaxThreadNum = 23;

// Computes max_thread_num_:
//   1. Take one less than the reported hardware concurrency, but never below 1.
//   2. Divide by kDeviceNum when ENABLE_D is defined.
//   3. Clamp to [1, kMaxThreadNum].
// No threads are created here; SyncRun starts workers lazily.
ThreadPool::ThreadPool() {
  // hardware_concurrency() may return 0 when the value is not computable.
  // Subtracting 1 from that unsigned 0 wraps around to a huge value, which
  // bypassed the "< 1" guard and produced kMaxThreadNum threads instead of 1.
  // Check before subtracting.
  const size_t hardware_threads = static_cast<size_t>(std::thread::hardware_concurrency());
  const size_t process_core_num = hardware_threads > 1 ? hardware_threads - 1 : 1;
#ifdef ENABLE_D
  max_thread_num_ = process_core_num / kDeviceNum;
#else
  max_thread_num_ = process_core_num;
#endif
  if (max_thread_num_ < 1) {
    max_thread_num_ = 1;
  }
  if (max_thread_num_ > kMaxThreadNum) {
    max_thread_num_ = kMaxThreadNum;
  }
}

// Worker thread body. Each iteration:
//   - blocks until a task is queued or exit_run_ is set;
//   - returns as soon as exit_run_ is observed;
//   - otherwise pops one task and runs it outside the lock;
//   - increments task_finished_count_ under task_mutex_ and wakes the thread
//     waiting in SyncRun.
void ThreadPool::SyncRunLoop() {
  while (true) {
    Task task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
      if (exit_run_) {
        return;
      }
      task = std::move(task_queue_.front());
      task_queue_.pop();
    }
    // Nothing may escape a thread entry point (it would call std::terminate), so
    // catch everything and record it. SetException() is called from inside the
    // handler, so the in-flight exception is still active while it runs.
    // NOTE(review): assumes MsException::SetException() captures
    // std::current_exception() — confirm in utils/ms_exception.h.
    try {
      task();
    } catch (...) {
      MsException::Instance().SetException();
    }
    {
      // Count under task_mutex_ so SyncRun's wait predicate cannot miss the update.
      std::lock_guard<std::mutex> task_lock(task_mutex_);
      ++task_finished_count_;
    }
    finished_cond_var_.notify_one();
  }
}

// Runs every task in `tasks` and blocks until all of them have finished.
//
// - An empty vector returns true immediately.
// - A single task runs inline on the calling thread, and any exception it throws
//   propagates to the caller.
// - Several tasks are dispatched to the workers, starting up to
//   min(max_thread_num_, tasks.size()) threads if fewer exist. Exceptions thrown
//   by these tasks are recorded through MsException instead of propagating.
//
// Returns true iff every task that returned a value returned SUCCESS.
// Concurrent SyncRun / ClearThreadPool calls are serialized through pool_mtx_.
bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
  if (tasks.empty()) {
    return true;
  }
  if (tasks.size() == 1) {
    auto ret = tasks[0]();
    return ret == SUCCESS;
  }
  std::lock_guard<std::mutex> pool_lock(pool_mtx_);
  {
    // Workers read exit_run_ under task_mutex_, so write it under the same mutex.
    // This re-arms the pool after a previous ClearThreadPool().
    std::lock_guard<std::mutex> task_lock(task_mutex_);
    exit_run_ = false;
  }
  const size_t task_num = tasks.size();
  const size_t wanted_thread_num = std::min<size_t>(max_thread_num_, task_num);
  for (size_t i = sync_run_threads_.size(); i < wanted_thread_num; ++i) {
    sync_run_threads_.emplace_back(&ThreadPool::SyncRunLoop, this);
  }

  // Each task is wrapped so a non-SUCCESS return is reported to the caller, just
  // as the single-task path already does. Previously these results were dropped.
  // `all_succeeded` outlives every wrapper because this function waits below
  // until all task_num tasks have been counted as finished.
  std::atomic<bool> all_succeeded{true};
  {
    std::lock_guard<std::mutex> task_lock(task_mutex_);
    for (const auto &task : tasks) {
      task_queue_.push([task, &all_succeeded]() {
        auto ret = task();
        if (ret != SUCCESS) {
          all_succeeded = false;
        }
        return ret;
      });
    }
  }
  task_cond_var_.notify_all();
  {
    std::unique_lock<std::mutex> task_lock(task_mutex_);
    finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
    task_finished_count_ = 0;
  }
  return all_succeeded.load();
}

// Returns the process-wide pool, constructed on first use.
ThreadPool &ThreadPool::GetInstance() {
  static ThreadPool instance;
  return instance;
}

// Signals all workers to exit, joins them, and forgets them. Calling it again
// before another SyncRun is a no-op, because exit_run_ is already set.
void ThreadPool::ClearThreadPool() {
  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
  {
    // The store must happen under task_mutex_ (the mutex the workers wait on).
    // Otherwise a worker could evaluate its predicate, miss the store, then block
    // forever after notify_all, leaving join() below hung.
    std::lock_guard<std::mutex> task_lock(task_mutex_);
    if (exit_run_) {
      return;
    }
    exit_run_ = true;
  }
  task_cond_var_.notify_all();
  for (auto &worker : sync_run_threads_) {
    if (worker.joinable()) {
      worker.join();
    }
  }
  sync_run_threads_.clear();
}

// Stops and joins any remaining workers.
ThreadPool::~ThreadPool() { ClearThreadPool(); }
}  // namespace common
}  // namespace mindspore