|
|
|
@@ -24,38 +24,12 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace common { |
|
|
|
#ifdef ENABLE_D |
|
|
|
const int kDeviceNum = 8; |
|
|
|
const size_t kDeviceNum = 8; |
|
|
|
#endif |
|
|
|
const int kMaxThreadNum = 23; |
|
|
|
bool Queue::Enqueue(Task *task) { |
|
|
|
const int tail_index = tail_.load(std::memory_order_relaxed); |
|
|
|
// queue full |
|
|
|
auto next = (tail_index + 1) % 2; |
|
|
|
if (next == head_.load(std::memory_order_acquire)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
buffer_[tail_index] = task; |
|
|
|
tail_.store(next, std::memory_order_release); |
|
|
|
++task_size_; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool Queue::Dequeue(Task **out) { |
|
|
|
if (task_size_ == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
// queue empty |
|
|
|
const int head_index = head_.load(std::memory_order_relaxed); |
|
|
|
if (head_index == tail_.load(std::memory_order_acquire)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
*out = buffer_[head_index]; |
|
|
|
head_.store((head_index + 1) % 2, std::memory_order_release); |
|
|
|
return true; |
|
|
|
} |
|
|
|
const size_t kMaxThreadNum = 23; |
|
|
|
|
|
|
|
ThreadPool::ThreadPool() { |
|
|
|
int process_core_num = std::thread::hardware_concurrency() - 1; |
|
|
|
size_t process_core_num = std::thread::hardware_concurrency() - 1; |
|
|
|
if (process_core_num < 1) { |
|
|
|
process_core_num = 1; |
|
|
|
} |
|
|
|
@@ -72,80 +46,6 @@ ThreadPool::ThreadPool() { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool ThreadPool::SetThreadPool(int config_thread_num) { |
|
|
|
if (config_thread_num > max_thread_num_) { |
|
|
|
MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num=" |
|
|
|
<< config_thread_num << ", allowed max thread num=" << max_thread_num_; |
|
|
|
} |
|
|
|
if (config_thread_num > cur_thread_nums_) { |
|
|
|
AddNewThread(config_thread_num - cur_thread_nums_); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "cur_thread_nums_=" << cur_thread_nums_ << ", cur_thread_run_nums_=" << cur_thread_run_nums_; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void ThreadPool::AddNewThread(int add_num) { |
|
|
|
for (int i = cur_thread_nums_, j = 0; j < add_num; ++i, ++j) { |
|
|
|
auto active = new std::atomic_bool{true}; |
|
|
|
auto queue = std::make_shared<Queue>(); |
|
|
|
std::thread thread([this, i, active, queue]() { |
|
|
|
Task *task = nullptr; |
|
|
|
while (!exit_run_) { |
|
|
|
while (*active) { |
|
|
|
if (queue->Dequeue(&task)) { |
|
|
|
int ret; |
|
|
|
try { |
|
|
|
ret = (*task)(); |
|
|
|
} catch (std::exception &e) { |
|
|
|
ret = FAIL; |
|
|
|
MsException::Instance().SetException(); |
|
|
|
} |
|
|
|
if (ret != SUCCESS) { |
|
|
|
error_info_.emplace_back(std::make_pair(i, std::make_pair(false, ret))); |
|
|
|
} |
|
|
|
queue->task_size_--; |
|
|
|
} |
|
|
|
std::this_thread::yield(); |
|
|
|
} |
|
|
|
std::unique_lock<std::mutex> queue_lock(thread_mtx_); |
|
|
|
queue_ready_.wait(queue_lock, [active, this] { return exit_run_ || *active; }); |
|
|
|
} |
|
|
|
}); |
|
|
|
thread_list_.emplace_back(std::move(thread)); |
|
|
|
activate_list_.emplace_back(active); |
|
|
|
queue_list_.emplace_back(queue); |
|
|
|
} |
|
|
|
cur_thread_nums_ += add_num; |
|
|
|
cur_thread_run_nums_ += add_num; |
|
|
|
MS_LOG(INFO) << "add " << add_num << " thread"; |
|
|
|
} |
|
|
|
|
|
|
|
void ThreadPool::AddRunThread(int num) { |
|
|
|
MS_LOG(DEBUG) << "num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_; |
|
|
|
int active_nums = num - cur_thread_run_nums_; |
|
|
|
if (active_nums <= 0 || static_cast<int>(activate_list_.size()) < active_nums) { |
|
|
|
return; |
|
|
|
} |
|
|
|
for (int i = cur_thread_run_nums_ - 1, j = 0; j < active_nums; ++i, ++j) { |
|
|
|
*activate_list_[i] = true; |
|
|
|
} |
|
|
|
std::lock_guard<std::mutex> queueLock(thread_mtx_); |
|
|
|
queue_ready_.notify_all(); |
|
|
|
cur_thread_run_nums_ = num; |
|
|
|
} |
|
|
|
|
|
|
|
void ThreadPool::SubRunThread(int num) { |
|
|
|
MS_LOG(DEBUG) << "sub num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_; |
|
|
|
int deactive_nums = cur_thread_run_nums_ - num; |
|
|
|
if (deactive_nums <= 0) { |
|
|
|
return; |
|
|
|
} |
|
|
|
for (int i = num, j = 0; j < deactive_nums; ++i, ++j) { |
|
|
|
*activate_list_[i] = false; |
|
|
|
} |
|
|
|
cur_thread_run_nums_ = num; |
|
|
|
} |
|
|
|
|
|
|
|
void ThreadPool::SyncRunLoop() { |
|
|
|
while (true) { |
|
|
|
Task task; |
|
|
|
@@ -178,14 +78,14 @@ bool ThreadPool::SyncRun(const std::vector<Task> &tasks) { |
|
|
|
} |
|
|
|
std::unique_lock<std::mutex> lock(pool_mtx_); |
|
|
|
exit_run_ = false; |
|
|
|
int task_num = tasks.size(); |
|
|
|
int thread_num = sync_run_threads_.size(); |
|
|
|
size_t task_num = tasks.size(); |
|
|
|
size_t thread_num = sync_run_threads_.size(); |
|
|
|
if (thread_num < max_thread_num_ && thread_num < task_num) { |
|
|
|
auto new_thread_num = max_thread_num_; |
|
|
|
if (task_num < max_thread_num_) { |
|
|
|
new_thread_num = task_num; |
|
|
|
} |
|
|
|
for (int i = thread_num; i < new_thread_num; ++i) { |
|
|
|
for (size_t i = thread_num; i < new_thread_num; ++i) { |
|
|
|
sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this)); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -203,56 +103,6 @@ bool ThreadPool::SyncRun(const std::vector<Task> &tasks) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool ThreadPool::InnerSyncRun(const std::vector<Task> &tasks) { |
|
|
|
std::lock_guard<std::mutex> sync_run_lock(pool_mtx_); |
|
|
|
int thread_num = tasks.size(); |
|
|
|
if (thread_num > max_thread_num_) { |
|
|
|
thread_num = max_thread_num_; |
|
|
|
} |
|
|
|
if (!SetThreadPool(thread_num)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
error_info_.clear(); |
|
|
|
bool succ_flag; |
|
|
|
for (int task_id = 0, queue_index = 0; task_id < SizeToInt(tasks.size()); ++task_id) { |
|
|
|
do { |
|
|
|
succ_flag = true; |
|
|
|
if (!queue_list_[queue_index]->Enqueue(const_cast<Task *>(&tasks[task_id]))) { |
|
|
|
std::this_thread::yield(); |
|
|
|
succ_flag = false; |
|
|
|
} |
|
|
|
} while (!succ_flag); |
|
|
|
queue_index++; |
|
|
|
if (queue_index >= cur_thread_run_nums_) { |
|
|
|
queue_index = queue_index - cur_thread_run_nums_; |
|
|
|
} |
|
|
|
} |
|
|
|
succ_flag = false; |
|
|
|
while (!succ_flag) { |
|
|
|
std::this_thread::yield(); |
|
|
|
succ_flag = true; |
|
|
|
for (int i = 0; i < cur_thread_run_nums_; ++i) { |
|
|
|
if (queue_list_[i]->task_size_ != 0) { |
|
|
|
succ_flag = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Finish " << tasks.size() << " task successful"; |
|
|
|
return CheckResult(); |
|
|
|
} |
|
|
|
|
|
|
|
bool ThreadPool::CheckResult() { |
|
|
|
bool succ_flag = true; |
|
|
|
for (auto result : error_info_) { |
|
|
|
if (result.second.first) { |
|
|
|
MS_LOG(ERROR) << "task " << result.first << " failed, error code is " << result.second.second; |
|
|
|
succ_flag = false; |
|
|
|
} |
|
|
|
} |
|
|
|
return succ_flag; |
|
|
|
} |
|
|
|
|
|
|
|
ThreadPool &ThreadPool::GetInstance() { |
|
|
|
static ThreadPool instance; |
|
|
|
return instance; |
|
|
|
@@ -264,9 +114,6 @@ void ThreadPool::ClearThreadPool() { |
|
|
|
return; |
|
|
|
} |
|
|
|
exit_run_ = true; |
|
|
|
cur_thread_run_nums_ = static_cast<int>(thread_list_.size()); |
|
|
|
SubRunThread(0); |
|
|
|
queue_ready_.notify_all(); |
|
|
|
task_cond_var_.notify_all(); |
|
|
|
for (auto &it : sync_run_threads_) { |
|
|
|
if (it.joinable()) { |
|
|
|
@@ -274,16 +121,6 @@ void ThreadPool::ClearThreadPool() { |
|
|
|
} |
|
|
|
} |
|
|
|
sync_run_threads_.clear(); |
|
|
|
for (auto &it : thread_list_) { |
|
|
|
if (it.joinable()) { |
|
|
|
it.join(); |
|
|
|
} |
|
|
|
} |
|
|
|
thread_list_.clear(); |
|
|
|
for (const auto &it : activate_list_) { |
|
|
|
delete it; |
|
|
|
} |
|
|
|
activate_list_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
ThreadPool::~ThreadPool() { ClearThreadPool(); } |
|
|
|
|