From 74e938cc7767f7f7f172664ebc8aaff7f06a6aec Mon Sep 17 00:00:00 2001 From: yangjie159 Date: Wed, 16 Jun 2021 22:24:01 +0800 Subject: [PATCH] fix threadpool when not use mindrt --- .../mindrt/src/thread/actor_threadpool.cc | 18 +++--- .../core/mindrt/src/thread/threadpool.cc | 51 ++++++++++--------- mindspore/core/mindrt/src/thread/threadpool.h | 11 ++-- 3 files changed, 45 insertions(+), 35 deletions(-) diff --git a/mindspore/core/mindrt/src/thread/actor_threadpool.cc b/mindspore/core/mindrt/src/thread/actor_threadpool.cc index 3989a1b041..8404a93d43 100644 --- a/mindspore/core/mindrt/src/thread/actor_threadpool.cc +++ b/mindspore/core/mindrt/src/thread/actor_threadpool.cc @@ -36,14 +36,13 @@ ActorThreadPool::~ActorThreadPool() { void ActorThreadPool::AsyncRunMultiTask(Worker *worker) { THREAD_RETURN_IF_NULL(worker); while (alive_) { - if (RunLocalKernelTask(worker) || RunPoolQueueActorTask()) { - // only run either local KernelTask or PoolQueue ActorTask - } else { + // only run either local KernelTask or PoolQueue ActorTask + bool busy = RunLocalKernelTask(worker) || RunPoolQueueActorTask(); + if (!busy) { // wait until Actor enqueue or distribute KernelTask - worker->running = false; std::unique_lock<std::mutex> _l(worker->mutex); - worker->cond_var.wait( - _l, [&] { return worker->task != nullptr || (worker->running && !actor_queue_.empty()) || !alive_; }); + worker->status = kThreadIdle; + worker->cond_var.wait(_l, [&] { return worker->status == kThreadBusy || !alive_; }); } } } @@ -75,10 +74,11 @@ void ActorThreadPool::EnqueReadyActor(const ActorReference &actor) { actor_queue_.push(actor); } THREAD_INFO("actor[%s] enqueue success", actor->GetAID().Name().c_str()); - // active one free actor thread + // active one idle actor thread if exist for (size_t i = 0; i < actor_thread_num_; ++i) { - bool expected = false; - if (workers_[i]->running.compare_exchange_strong(expected, true)) { + std::lock_guard<std::mutex> _l(workers_[i]->mutex); + if (workers_[i]->status == kThreadIdle) { + 
workers_[i]->status = kThreadBusy; workers_[i]->cond_var.notify_one(); break; } } diff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc index 8e2ccabd27..11e85ed0a9 100644 --- a/mindspore/core/mindrt/src/thread/threadpool.cc +++ b/mindspore/core/mindrt/src/thread/threadpool.cc @@ -68,9 +68,8 @@ void ThreadPool::AsyncRunTask(Worker *worker) const { } if (worker->spin >= kDefaultSpinCount) { // wait until distribute KernelTask - worker->spin = 0; std::unique_lock<std::mutex> _l(worker->mutex); - worker->cond_var.wait(_l, [&] { return worker->running || !alive_; }); + worker->cond_var.wait(_l, [&] { return worker->status == kThreadBusy || !alive_; }); } } } @@ -78,7 +77,7 @@ void ThreadPool::AsyncRunTask(Worker *worker) const { void ThreadPool::YieldAndDeactive(Worker *worker) const { // deactivate this worker only on the first entry if (worker->spin == 0) { - worker->running = false; + worker->status.store(kThreadIdle); } worker->spin++; std::this_thread::yield(); } @@ -101,12 +100,8 @@ int ThreadPool::ParallelLaunch(const Func &func, Content content, int task_num) // if the task num is greater than the KernelThread num THREAD_INFO("launch: %d", task_num); Task task = Task(func, content); - Worker *curr = CurrentWorker(); - if (curr == nullptr) { - SyncRunTask(&task, task_num); - } else { - DistributeTask(&task, task_num); - } + + DistributeTask(&task, task_num); // synchronization // wait until the finished is equal to task_num while (task.finished != task_num) { @@ -134,28 +129,34 @@ void ThreadPool::SyncRunTask(Task *task, int task_num) const { void ThreadPool::DistributeTask(Task *task, int task_num) const { Worker *curr = CurrentWorker(); - THREAD_RETURN_IF_NULL(curr); - - int count = 1; + // if the current thread isn't nullptr, that is the curr is a ActorThread, + // then the count is equal to 1. otherwise the count is equal to 0 + int count = curr != nullptr ? 
1 : 0; int sum_frequency = 0; std::vector<Worker *> assigned; int num = static_cast<int>(workers_.size()) - 1; for (int i = num; i >= 0 && count < task_num; --i) { - bool expected = false; - if (workers_[i]->running.compare_exchange_strong(expected, true)) { + int expected = kThreadIdle; + if (workers_[i]->status.compare_exchange_strong(expected, kThreadHeld)) { assigned.push_back(workers_[i]); sum_frequency += workers_[i]->frequency; count++; } } - assigned.push_back(curr); - for (; count < task_num; ++count) { + // when there are not enough free threads, + // distribute other tasks to the master thread + if (curr != nullptr) { assigned.push_back(curr); - sum_frequency += curr->frequency; + for (; count < task_num; ++count) { + assigned.push_back(curr); + sum_frequency += curr->frequency; + } + } else if (assigned.size() != static_cast<size_t>(task_num)) { + SyncRunTask(task, task_num); + return; } - CalculateScales(assigned, sum_frequency); - ActiveWorkers(assigned, task, task_num); + ActiveWorkers(assigned, task, task_num, curr); } void ThreadPool::CalculateScales(const std::vector<Worker *> &assigned, int sum_frequency) const { @@ -170,13 +171,17 @@ void ThreadPool::CalculateScales(const std::vector<Worker *> &assigned, int sum_ } } -void ThreadPool::ActiveWorkers(const std::vector<Worker *> &workers, Task *task, int task_num) const { - Worker *curr = workers.back(); +void ThreadPool::ActiveWorkers(const std::vector<Worker *> &workers, Task *task, int task_num, + const Worker *curr) const { for (int i = 0; i < task_num; ++i) { Worker *worker = workers[i]; THREAD_RETURN_IF_NULL(worker); - worker->task_id.store(i, std::memory_order_relaxed); - worker->task.store(task, std::memory_order_relaxed); + { + std::lock_guard<std::mutex> _l(worker->mutex); + worker->task_id.store(i, std::memory_order_relaxed); + worker->task.store(task, std::memory_order_relaxed); + worker->status = kThreadBusy; + } worker->cond_var.notify_one(); if (worker == curr) { RunLocalKernelTask(worker); diff --git a/mindspore/core/mindrt/src/thread/threadpool.h 
b/mindspore/core/mindrt/src/thread/threadpool.h index cb16991ef5..94b5ea17ca 100644 --- a/mindspore/core/mindrt/src/thread/threadpool.h +++ b/mindspore/core/mindrt/src/thread/threadpool.h @@ -45,9 +45,14 @@ typedef struct Task { std::atomic_int status{THREAD_OK}; // return status, RET_OK } Task; +// busy, the thread is running task +// held, the thread has been marked as occupied +// idle, the thread is waiting +enum ThreadStatus { kThreadBusy = 0, kThreadHeld = 1, kThreadIdle = 2 }; + typedef struct Worker { std::thread thread; - std::atomic_bool running{false}; + std::atomic_int status{kThreadBusy}; std::mutex mutex; std::condition_variable cond_var; std::atomic<Task *> task{nullptr}; @@ -84,9 +89,9 @@ class ThreadPool { void DistributeTask(Task *task, int task_num) const; void CalculateScales(const std::vector<Worker *> &workers, int sum_frequency) const; - void ActiveWorkers(const std::vector<Worker *> &workers, Task *task, int task_num) const; - + void ActiveWorkers(const std::vector<Worker *> &workers, Task *task, int task_num, const Worker *curr) const; void YieldAndDeactive(Worker *worker) const; + bool RunLocalKernelTask(Worker *worker) const; Worker *CurrentWorker() const;