
!10304 fix thread pool sync run error

From: @kisnwang
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
Committed by mindspore-ci-bot on Gitee 5 years ago · parent commit 59f4484a29
4 changed files with 108 additions and 13 deletions
  1. +16 -5    mindspore/ccsrc/backend/session/executor.cc
  2. +2  -1    mindspore/ccsrc/backend/session/executor_manager.cc
  3. +81 -4    mindspore/ccsrc/common/thread_pool.cc
  4. +9  -3    mindspore/ccsrc/common/thread_pool.h
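Taken together, the patch stops SyncRun from doing ad-hoc work per call: tasks now go through a mutex-guarded queue to persistent worker threads (SyncRunLoop), the submitting thread blocks on a finished-task counter, and ClearThreadPool gives the pool an explicit, idempotent shutdown path. Below is a minimal, self-contained sketch of that producer/consumer handshake; it is illustrative only, and every name in it (MiniSyncPool and friends) is invented rather than taken from MindSpore:

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Sketch of the pattern the patch introduces: persistent workers pop tasks
// from a shared queue; the submitter blocks on a second condition variable
// until the finished counter reaches the batch size.
class MiniSyncPool {
 public:
  explicit MiniSyncPool(int n) {
    for (int i = 0; i < n; ++i) workers_.emplace_back([this] { Loop(); });
  }
  ~MiniSyncPool() {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      exit_ = true;  // wake every worker and let it return
    }
    cv_.notify_all();
    for (auto &t : workers_) t.join();
  }
  void SyncRun(const std::vector<std::function<void()>> &tasks) {
    for (const auto &t : tasks) {
      std::lock_guard<std::mutex> lk(mtx_);
      queue_.push(t);
      cv_.notify_one();
    }
    std::unique_lock<std::mutex> lk(mtx_);
    int n = static_cast<int>(tasks.size());
    done_cv_.wait(lk, [this, n] { return finished_ == n; });
    finished_ = 0;  // reset for the next batch, as the patch does
  }

 private:
  void Loop() {
    while (true) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lk(mtx_);
        cv_.wait(lk, [this] { return !queue_.empty() || exit_; });
        if (exit_) return;
        task = queue_.front();
        queue_.pop();
      }
      task();  // the real code also catches exceptions here
      {
        std::lock_guard<std::mutex> lk(mtx_);
        ++finished_;  // count the task as done even if it failed
      }
      done_cv_.notify_one();
    }
  }
  std::mutex mtx_;
  std::condition_variable cv_, done_cv_;
  std::queue<std::function<void()>> queue_;
  std::vector<std::thread> workers_;
  int finished_{0};
  bool exit_{false};
};

int main() {
  MiniSyncPool pool(4);
  std::vector<std::function<void()>> tasks;
  for (int i = 0; i < 8; ++i) tasks.push_back([i] { std::cout << "task " << i << "\n"; });
  pool.SyncRun(tasks);  // returns only after all eight tasks have run
}

The two condition variables split the two waits cleanly: workers sleep until work or shutdown arrives, while the submitter sleeps until the finished count matches the batch size.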

+16 -5   mindspore/ccsrc/backend/session/executor.cc

@@ -122,13 +122,13 @@ void RunGraphTask::Run() {
     ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
     MsException::Instance().SetException();
   }
-  MS_LOG(INFO) << "End run graph " << graph_id_;
   graph->OnRunGraphFinished();
   for (auto &tensor : input_need_lock_tensors_) {
     tensor->SetNeedWait(false);
   }
   NotifyOutputTensors(&outputs_);
   ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
+  MS_LOG(INFO) << "End run graph " << graph_id_;
 }
 
 void RunOpTask::Run() {
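The hunk above moves the "End run graph" log to the end of the task, so it is only emitted after the output tensors have been notified and the finished event has fired, i.e. when the graph run is actually complete.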
@@ -213,11 +213,22 @@ std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
 void Executor::OnEvent(const ExecutorEvent &event) {
   if (event == ExecutorEvent::kRunGraphFinished) {
     OnRunGraphFinished();
   } else if (event == ExecutorEvent::kClear) {
     WorkerJoin();
   } else if (event == ExecutorEvent::kException) {
-    std::unique_lock<std::mutex> lock(task_mutex_);
-    while (!ready_tasks_.empty()) {
-      done_tasks_.emplace_back(ready_tasks_.front());
-      ready_tasks_.pop();
+    {
+      std::unique_lock<std::mutex> lock(task_mutex_);
+      while (!ready_tasks_.empty()) {
+        done_tasks_.emplace_back(ready_tasks_.front());
+        ready_tasks_.pop();
+      }
+    }
+    {
+      std::unique_lock<std::mutex> lock(pending_task_mutex_);
+      for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end(); ++iter) {
+        done_tasks_.emplace_back(*iter);
+      }
+      pending_tasks_.clear();
     }
   }
 }
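Draining the two queues in separate scopes means task_mutex_ and pending_task_mutex_ are never held at the same time here, so the exception path never pins both locks while worker threads are still filling those queues.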


+2 -1   mindspore/ccsrc/backend/session/executor_manager.cc

@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "backend/session/executor_manager.h"
+#include "common/thread_pool.h"
 namespace mindspore {
 namespace session {
 std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device_name, int device_id) {
@@ -40,6 +40,7 @@ void ExecutorManager::OnEvent(const ExecutorEvent &event) {
 void ExecutorManager::Clear() {
   OnEvent(ExecutorEvent::kClear);
   executors_.clear();
+  common::ThreadPool::GetInstance().ClearThreadPool();
 }
 } // namespace session
 } // namespace mindspore
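Ordering matters here: the kClear event joins each executor's worker and executors_ is cleared before ClearThreadPool() tears down the shared pool, so no executor task can still be submitted to a pool whose workers are shutting down.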

+81 -4   mindspore/ccsrc/common/thread_pool.cc

@@ -26,7 +26,7 @@ namespace common {
 #ifdef ENABLE_D
 const int kDeviceNum = 8;
 #endif
+const int kMaxThreadNum = 23;
 bool Queue::Enqueue(Task *task) {
   const int tail_index = tail_.load(std::memory_order_relaxed);
   // queue full
@@ -64,11 +64,15 @@ ThreadPool::ThreadPool() {
 #else
   max_thread_num_ = process_core_num;
 #endif
-  SetThreadPool(core_thread_num_);
+  if (max_thread_num_ < 1) {
+    max_thread_num_ = 1;
+  }
+  if (max_thread_num_ > kMaxThreadNum) {
+    max_thread_num_ = kMaxThreadNum;
+  }
 }
 
 bool ThreadPool::SetThreadPool(int config_thread_num) {
   std::lock_guard<std::mutex> Lock(pool_mtx_);
   if (config_thread_num > max_thread_num_) {
     MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num="
                       << config_thread_num << ", allowed max thread num=" << max_thread_num_;
@@ -142,7 +146,65 @@ void ThreadPool::SubRunThread(int num) {
   cur_thread_run_nums_ = num;
 }
 
-bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
+void ThreadPool::SyncRunLoop() {
+  while (true) {
+    Task task;
+    {
+      std::unique_lock<std::mutex> lock(task_mutex_);
+      task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
+      if (exit_run_) {
+        return;
+      }
+      task = task_queue_.front();
+      task_queue_.pop();
+    }
+    try {
+      task();
+    } catch (std::exception &e) {
+      MsException::Instance().SetException();
+    }
+    {
+      std::unique_lock<std::mutex> task_lock(task_mutex_);
+      task_finished_count_ = task_finished_count_ + 1;
+    }
+    finished_cond_var_.notify_one();
+  }
+}
+
+bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
+  if (tasks.size() == 1) {
+    auto ret = tasks[0]();
+    return ret == SUCCESS;
+  }
+  std::unique_lock<std::mutex> lock(pool_mtx_);
+  exit_run_ = false;
+  int task_num = tasks.size();
+  int thread_num = sync_run_threads_.size();
+  if (thread_num < max_thread_num_ && thread_num < task_num) {
+    auto new_thread_num = max_thread_num_;
+    if (task_num < max_thread_num_) {
+      new_thread_num = task_num;
+    }
+    for (int i = thread_num; i < new_thread_num; ++i) {
+      sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
+    }
+  }
+
+  for (auto &task : tasks) {
+    std::lock_guard<std::mutex> task_lock(task_mutex_);
+    task_queue_.push(task);
+    task_cond_var_.notify_one();
+  }
+  {
+    std::unique_lock<std::mutex> task_lock(task_mutex_);
+    finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
+    task_finished_count_ = 0;
+  }
+  return true;
+}
+
+bool ThreadPool::InnerSyncRun(const std::vector<Task> &tasks) {
   std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
   int thread_num = tasks.size();
   if (thread_num > max_thread_num_) {
     thread_num = max_thread_num_;
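Two details in the new code above are worth noting. SyncRunLoop increments task_finished_count_ even when a task throws (the exception is parked via MsException::SetException and re-raised later), so SyncRun's wait on finished_cond_var_ cannot hang on a failing task; this appears to be the "sync run error" the title refers to. And sync_run_threads_ grows lazily, only up to the smaller of the task count and max_thread_num_. A hedged usage sketch, assuming Task is a callable returning the pool's status code (SUCCESS and common::ThreadPool appear in the diff; the slicing below is invented):

std::vector<common::Task> tasks;
for (int i = 0; i < 4; ++i) {
  tasks.emplace_back([i]() {
    // ... process slice i of the work (hypothetical payload) ...
    return SUCCESS;
  });
}
// Blocks until all four tasks have finished on the pool's persistent workers.
common::ThreadPool::GetInstance().SyncRun(tasks);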
@@ -196,19 +258,34 @@ ThreadPool &ThreadPool::GetInstance() {
   return instance;
 }
 
-ThreadPool::~ThreadPool() {
+void ThreadPool::ClearThreadPool() {
+  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
+  if (exit_run_) {
+    return;
+  }
   exit_run_ = true;
   cur_thread_run_nums_ = static_cast<int>(thread_list_.size());
   SubRunThread(0);
   queue_ready_.notify_all();
+  task_cond_var_.notify_all();
+  for (auto &it : sync_run_threads_) {
+    if (it.joinable()) {
+      it.join();
+    }
+  }
+  sync_run_threads_.clear();
   for (auto &it : thread_list_) {
     if (it.joinable()) {
       it.join();
     }
   }
+  thread_list_.clear();
   for (const auto &it : activate_list_) {
     delete it;
   }
+  activate_list_.clear();
 }
+
+ThreadPool::~ThreadPool() { ClearThreadPool(); }
 } // namespace common
 } // namespace mindspore
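Because ClearThreadPool() takes pool_mtx_ and returns early once exit_run_ is already set, it is safe to call twice: the explicit call from ExecutorManager::Clear() does the teardown, and the later call from the destructor becomes a no-op.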

+9 -3   mindspore/ccsrc/common/thread_pool.h

@@ -56,12 +56,10 @@ class ThreadPool {
   ~ThreadPool();
   ThreadPool(const ThreadPool &) = delete;
   ThreadPool &operator=(const ThreadPool &) = delete;
 
   static ThreadPool &GetInstance();
   // Executes the tasks in parallel, one task per thread where possible (bounded by max_thread_num_).
   bool SyncRun(const std::vector<Task> &tasks);
+  size_t GetSyncRunThreadNum() { return max_thread_num_; }
+  void ClearThreadPool();
 
  private:
   ThreadPool();
@@ -70,6 +68,8 @@ class ThreadPool {
   void AddRunThread(int num);
   void SubRunThread(int num);
   bool CheckResult();
+  bool InnerSyncRun(const std::vector<Task> &tasks);
+  void SyncRunLoop();
 
   int cur_thread_nums_{0};
   int cur_thread_run_nums_{0};
@@ -83,6 +83,12 @@ class ThreadPool {
   std::vector<std::thread> thread_list_{};
   std::vector<std::shared_ptr<Queue>> queue_list_{};
   std::vector<std::pair<int, std::pair<bool, int>>> error_info_{};
+  std::queue<Task> task_queue_;
+  std::mutex task_mutex_;
+  std::condition_variable task_cond_var_;
+  int task_finished_count_{0};
+  std::condition_variable finished_cond_var_;
+  std::vector<std::thread> sync_run_threads_{};
 };
 } // namespace common
 } // namespace mindspore
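The six new members map directly onto the mechanism in thread_pool.cc: task_queue_, task_mutex_, and task_cond_var_ feed the workers; task_finished_count_ and finished_cond_var_ implement the completion handshake that SyncRun waits on; and sync_run_threads_ owns the lazily created worker threads that ClearThreadPool joins on shutdown.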

