|
|
|
@@ -26,7 +26,7 @@ namespace common { |
|
|
|
#ifdef ENABLE_D |
|
|
|
const int kDeviceNum = 8; |
|
|
|
#endif |
|
|
|
|
|
|
|
const int kMaxThreadNum = 23; |
|
|
|
bool Queue::Enqueue(Task *task) { |
|
|
|
const int tail_index = tail_.load(std::memory_order_relaxed); |
|
|
|
// queue full |
|
|
|
@@ -64,11 +64,15 @@ ThreadPool::ThreadPool() { |
|
|
|
#else |
|
|
|
max_thread_num_ = process_core_num; |
|
|
|
#endif |
|
|
|
SetThreadPool(core_thread_num_); |
|
|
|
if (max_thread_num_ < 1) { |
|
|
|
max_thread_num_ = 1; |
|
|
|
} |
|
|
|
if (max_thread_num_ > kMaxThreadNum) { |
|
|
|
max_thread_num_ = kMaxThreadNum; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool ThreadPool::SetThreadPool(int config_thread_num) { |
|
|
|
std::lock_guard<std::mutex> Lock(pool_mtx_); |
|
|
|
if (config_thread_num > max_thread_num_) { |
|
|
|
MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num=" |
|
|
|
<< config_thread_num << ", allowed max thread num=" << max_thread_num_; |
|
|
|
@@ -142,7 +146,65 @@ void ThreadPool::SubRunThread(int num) { |
|
|
|
cur_thread_run_nums_ = num; |
|
|
|
} |
|
|
|
|
|
|
|
void ThreadPool::SyncRunLoop() { |
|
|
|
while (true) { |
|
|
|
Task task; |
|
|
|
{ |
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_); |
|
|
|
task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; }); |
|
|
|
if (exit_run_) { |
|
|
|
return; |
|
|
|
} |
|
|
|
task = task_queue_.front(); |
|
|
|
task_queue_.pop(); |
|
|
|
} |
|
|
|
try { |
|
|
|
task(); |
|
|
|
} catch (std::exception &e) { |
|
|
|
MsException::Instance().SetException(); |
|
|
|
} |
|
|
|
{ |
|
|
|
std::unique_lock<std::mutex> task_lock(task_mutex_); |
|
|
|
task_finished_count_ = task_finished_count_ + 1; |
|
|
|
} |
|
|
|
finished_cond_var_.notify_one(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool ThreadPool::SyncRun(const std::vector<Task> &tasks) { |
|
|
|
if (tasks.size() == 1) { |
|
|
|
auto ret = tasks[0](); |
|
|
|
return ret == SUCCESS; |
|
|
|
} |
|
|
|
std::unique_lock<std::mutex> lock(pool_mtx_); |
|
|
|
exit_run_ = false; |
|
|
|
int task_num = tasks.size(); |
|
|
|
int thread_num = sync_run_threads_.size(); |
|
|
|
if (thread_num < max_thread_num_ && thread_num < task_num) { |
|
|
|
auto new_thread_num = max_thread_num_; |
|
|
|
if (task_num < max_thread_num_) { |
|
|
|
new_thread_num = task_num; |
|
|
|
} |
|
|
|
for (int i = thread_num; i < new_thread_num; ++i) { |
|
|
|
sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &task : tasks) { |
|
|
|
std::lock_guard<std::mutex> task_lock(task_mutex_); |
|
|
|
task_queue_.push(task); |
|
|
|
task_cond_var_.notify_one(); |
|
|
|
} |
|
|
|
{ |
|
|
|
std::unique_lock<std::mutex> task_lock(task_mutex_); |
|
|
|
finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; }); |
|
|
|
task_finished_count_ = 0; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool ThreadPool::InnerSyncRun(const std::vector<Task> &tasks) { |
|
|
|
std::lock_guard<std::mutex> sync_run_lock(pool_mtx_); |
|
|
|
int thread_num = tasks.size(); |
|
|
|
if (thread_num > max_thread_num_) { |
|
|
|
thread_num = max_thread_num_; |
|
|
|
@@ -196,19 +258,34 @@ ThreadPool &ThreadPool::GetInstance() { |
|
|
|
return instance; |
|
|
|
} |
|
|
|
|
|
|
|
ThreadPool::~ThreadPool() { |
|
|
|
void ThreadPool::ClearThreadPool() { |
|
|
|
std::lock_guard<std::mutex> sync_run_lock(pool_mtx_); |
|
|
|
if (exit_run_) { |
|
|
|
return; |
|
|
|
} |
|
|
|
exit_run_ = true; |
|
|
|
cur_thread_run_nums_ = static_cast<int>(thread_list_.size()); |
|
|
|
SubRunThread(0); |
|
|
|
queue_ready_.notify_all(); |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
for (auto &it : sync_run_threads_) { |
|
|
|
if (it.joinable()) { |
|
|
|
it.join(); |
|
|
|
} |
|
|
|
} |
|
|
|
sync_run_threads_.clear(); |
|
|
|
for (auto &it : thread_list_) { |
|
|
|
if (it.joinable()) { |
|
|
|
it.join(); |
|
|
|
} |
|
|
|
} |
|
|
|
thread_list_.clear(); |
|
|
|
for (const auto &it : activate_list_) { |
|
|
|
delete it; |
|
|
|
} |
|
|
|
activate_list_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
ThreadPool::~ThreadPool() { ClearThreadPool(); } |
|
|
|
} // namespace common |
|
|
|
} // namespace mindspore |