From: @kisnwang Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -20,15 +20,11 @@ | |||
| #include <memory> | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "common/thread_pool.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr size_t kAdamDeltaInputSize = 9; | |||
| #ifdef ENABLE_D | |||
| constexpr size_t kUsedThreadNum = 23; | |||
| #else | |||
| constexpr size_t kUsedThreadNum = 8; | |||
| #endif | |||
| namespace { | |||
| struct ComputeParam { | |||
| float *delta_{nullptr}; | |||
| @@ -139,13 +135,13 @@ bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| auto grad = reinterpret_cast<float *>(inputs[8]->addr); | |||
| auto delta = reinterpret_cast<float *>(outputs[0]->addr); | |||
| lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); | |||
| size_t thread_num = kUsedThreadNum; | |||
| size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| if (elem_num_ < thread_num) { | |||
| thread_num = elem_num_; | |||
| } | |||
| std::vector<std::thread> threads; | |||
| std::vector<common::Task> tasks; | |||
| std::vector<std::shared_ptr<ComputeParam>> thread_params; | |||
| threads.reserve(thread_num); | |||
| tasks.reserve(thread_num); | |||
| size_t end = 0; | |||
| size_t offset = elem_num_ / thread_num; | |||
| @@ -166,12 +162,14 @@ bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| if (i < left) { | |||
| end += 1; | |||
| } | |||
| threads.emplace_back(std::thread(ComputeWeightDelta, params, start, end)); | |||
| auto task = [¶ms, start, end]() { | |||
| ComputeWeightDelta(params, start, end); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| thread_params.emplace_back(params); | |||
| } | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads[i].join(); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -18,15 +18,11 @@ | |||
| #include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "ir/primitive.h" | |||
| #include "common/thread_pool.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| #ifdef ENABLE_D | |||
| constexpr size_t kUsedThreadNum = 23; | |||
| #else | |||
| constexpr size_t kUsedThreadNum = 8; | |||
| #endif | |||
| template <typename T> | |||
| void LookUpTableTask(const float *input_addr, const T *indices_addr, float *output_addr, size_t indices_lens, | |||
| size_t outer_dim_size, T offset, size_t first_dim_size) { | |||
| @@ -98,8 +94,9 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr | |||
| auto indices_addr = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | |||
| size_t thread_num = indices_lens_ / 10000 + 1; | |||
| thread_num = thread_num > kUsedThreadNum ? kUsedThreadNum : thread_num; | |||
| std::thread threads[kUsedThreadNum]; | |||
| auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| thread_num = thread_num > max_thread_num ? max_thread_num : thread_num; | |||
| std::vector<common::Task> tasks; | |||
| size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; | |||
| size_t i; | |||
| size_t task_offset = 0; | |||
| @@ -109,17 +106,18 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr | |||
| break; | |||
| } | |||
| MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; | |||
| threads[i] = std::thread(LookUpTableTask<T>, input_addr, indices_addr + task_offset, | |||
| output_addr + task_offset * outer_dim_size_, task_proc_lens, outer_dim_size_, offset_, | |||
| first_dim_size_); | |||
| auto task = [input_addr, indices_addr, output_addr, task_offset, task_proc_lens, this]() { | |||
| LookUpTableTask<T>(input_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size_, | |||
| task_proc_lens, outer_dim_size_, offset_, first_dim_size_); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| task_offset += task_proc_lens; | |||
| if (task_offset + task_proc_lens > indices_lens_) { | |||
| task_proc_lens = indices_lens_ - task_offset; | |||
| } | |||
| } | |||
| for (size_t j = 0; j < i; j++) { | |||
| threads[j].join(); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| } | |||
| bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| @@ -22,11 +22,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| #ifdef ENABLE_D | |||
| constexpr size_t kUsedThreadNum = 23; | |||
| #else | |||
| constexpr size_t kUsedThreadNum = 8; | |||
| #endif | |||
| template <typename T> | |||
| void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) { | |||
| MS_EXCEPTION_IF_NULL(params); | |||
| @@ -120,19 +115,20 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input | |||
| params.indices_unit_rank_ = indices_unit_rank_; | |||
| params.out_strides_ = &out_strides_; | |||
| std::vector<Task> tasks; | |||
| std::vector<common::Task> tasks; | |||
| size_t start = 0; | |||
| size_t once_compute_size = (num_units_ + kUsedThreadNum - 1) / kUsedThreadNum; | |||
| auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| size_t once_compute_size = (num_units_ + max_thread_num - 1) / max_thread_num; | |||
| while (start < num_units_) { | |||
| size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size); | |||
| auto task = [¶ms, start, end]() -> int { | |||
| Compute<T>(¶ms, start, end); | |||
| return SUCCESS; | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| start += once_compute_size; | |||
| } | |||
| ThreadPool::GetInstance()->LaunchMultipleTask(tasks); | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size); | |||
| if (ret != 0) { | |||
| @@ -18,20 +18,14 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <thread> | |||
| #include <unordered_map> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| #include "common/thread_pool.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| #ifdef ENABLE_D | |||
| constexpr size_t kUsedThreadNum = 23; | |||
| #else | |||
| constexpr size_t kUsedThreadNum = 8; | |||
| #endif | |||
| template <typename T> | |||
| struct SparseGradient { | |||
| float *value_{nullptr}; | |||
| @@ -100,7 +94,7 @@ class SparseOptimizerCPUKernel : public CPUKernel { | |||
| static void BucketReduceSparseGradient(const ReduceSparseGradientParam<T> ¶m) { | |||
| MS_LOG(DEBUG) << "Start"; | |||
| MS_EXCEPTION_IF_NULL(param.input_grad_); | |||
| size_t thread_num = kUsedThreadNum; | |||
| size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| if (param.input_grad_->indices_size_ < thread_num) { | |||
| thread_num = param.input_grad_->indices_size_; | |||
| } | |||
| @@ -125,18 +119,21 @@ class SparseOptimizerCPUKernel : public CPUKernel { | |||
| template <typename T> | |||
| void MultiThreadCompute(const MultiThreadComputeFunc<T> &func, MultiThreadComputeParams<T> *params, | |||
| size_t total_compute_size) const { | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(kUsedThreadNum); | |||
| std::vector<common::Task> tasks; | |||
| auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| tasks.reserve(max_thread_num); | |||
| size_t start = 0; | |||
| size_t once_compute_size = (total_compute_size + kUsedThreadNum - 1) / kUsedThreadNum; | |||
| size_t once_compute_size = (total_compute_size + max_thread_num - 1) / max_thread_num; | |||
| while (start < total_compute_size) { | |||
| size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size); | |||
| threads.emplace_back(std::thread(func, params, start, end)); | |||
| auto task = [&func, ¶ms, start, end]() { | |||
| func(params, start, end); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| threads[i].join(); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| } | |||
| private: | |||
| @@ -173,8 +170,8 @@ class SparseOptimizerCPUKernel : public CPUKernel { | |||
| } | |||
| size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_; | |||
| size_t left_indices_size = input_grad->indices_size_ % param.thread_num_; | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(param.thread_num_); | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(param.thread_num_); | |||
| segments.reserve(param.thread_num_); | |||
| size_t current_indices_offset = 0; | |||
| @@ -188,14 +185,14 @@ class SparseOptimizerCPUKernel : public CPUKernel { | |||
| segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_; | |||
| segments[i]->indices_ = input_grad->indices_ + current_indices_offset; | |||
| segments[i]->indices_size_ = indices_size; | |||
| threads.emplace_back( | |||
| std::thread(CalculateEachBucketSize<T>, segments[i], param.max_index_, segment_bucket_sizes[i].get())); | |||
| auto task = [&segments, ¶m, &segment_bucket_sizes, i]() { | |||
| CalculateEachBucketSize<T>(segments[i], param.max_index_, segment_bucket_sizes[i].get()); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| current_indices_offset += indices_size; | |||
| } | |||
| for (size_t i = 0; i < param.thread_num_; ++i) { | |||
| threads[i].join(); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| } | |||
| template <typename T> | |||
| @@ -263,17 +260,18 @@ class SparseOptimizerCPUKernel : public CPUKernel { | |||
| } | |||
| each_thread_buckets.emplace_back(thread_buckets); | |||
| } | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(thread_num); | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(thread_num); | |||
| current_indices_offset = 0; | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads.emplace_back( | |||
| std::thread(CopySegmentIndicesToBucket<T>, param, segments[i], current_indices_offset, each_thread_buckets[i])); | |||
| auto task = [¶m, &segments, &each_thread_buckets, i, current_indices_offset]() { | |||
| CopySegmentIndicesToBucket<T>(param, segments[i], current_indices_offset, each_thread_buckets[i]); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| current_indices_offset += segments[i]->indices_size_; | |||
| } | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads[i].join(); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| } | |||
| template <typename T> | |||
| @@ -381,8 +379,8 @@ class SparseOptimizerCPUKernel : public CPUKernel { | |||
| MS_EXCEPTION_IF_NULL(reduced_buckets_ptr); | |||
| auto &reduced_buckets = *reduced_buckets_ptr; | |||
| size_t thread_num = buckets.size(); | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(thread_num); | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(thread_num); | |||
| size_t current_indices_offset = 0; | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| @@ -390,16 +388,18 @@ class SparseOptimizerCPUKernel : public CPUKernel { | |||
| reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_; | |||
| reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset; | |||
| reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_; | |||
| if (param.use_sort_reduce_) { | |||
| threads.emplace_back(std::thread(SortAndReduceBucketSparseGradient<T>, param, buckets[i], reduced_buckets[i])); | |||
| } else { | |||
| threads.emplace_back(std::thread(ReduceBucketSparseGradient<T>, param, buckets[i], reduced_buckets[i])); | |||
| } | |||
| auto task = [¶m, &buckets, &reduced_buckets, i]() { | |||
| if (param.use_sort_reduce_) { | |||
| SortAndReduceBucketSparseGradient<T>(param, buckets[i], reduced_buckets[i]); | |||
| } else { | |||
| ReduceBucketSparseGradient<T>(param, buckets[i], reduced_buckets[i]); | |||
| } | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| current_indices_offset += buckets[i]->indices_size_; | |||
| } | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads[i].join(); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| } | |||
| template <typename T> | |||
| @@ -20,11 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| const size_t kUseBucketUniqueSize = 100000; | |||
| #ifdef ENABLE_D | |||
| constexpr size_t kUniqueThreadNum = 23; | |||
| #else | |||
| constexpr size_t kUniqueThreadNum = 8; | |||
| #endif | |||
| void UniqueCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| node_ = kernel_node; | |||
| CheckParam(kernel_node); | |||
| @@ -88,7 +83,7 @@ void UniqueCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const | |||
| params->input_size_ = input_size_; | |||
| params->output_size_ = 0; | |||
| params->need_sort_ = true; | |||
| params->thread_num_ = kUniqueThreadNum; | |||
| params->thread_num_ = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| if (input_size_ < kUseBucketUniqueSize) { | |||
| Unique(params); | |||
| } else { | |||
| @@ -23,6 +23,7 @@ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| #include "common/thread_pool.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -104,11 +105,11 @@ class UniqueCPUKernel : public CPUKernel { | |||
| } | |||
| IndexType thread_data_size = input_size / thread_num; | |||
| size_t left_data_size = input_size % thread_num; | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(thread_num); | |||
| segments.reserve(thread_num); | |||
| segment_bucket_sizes.reserve(thread_num); | |||
| IndexType current_offset = 0; | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(thread_num); | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| segment_bucket_sizes.emplace_back(std::make_shared<std::vector<IndexType>>(thread_num, 0)); | |||
| IndexType data_size = thread_data_size; | |||
| @@ -119,13 +120,14 @@ class UniqueCPUKernel : public CPUKernel { | |||
| segments[i]->input_ = params->input_ + current_offset; | |||
| segments[i]->input_size_ = data_size; | |||
| segments[i]->thread_num_ = thread_num; | |||
| threads.emplace_back( | |||
| std::thread(CalculateEachBucketSize<DataType, IndexType>, segments[i], segment_bucket_sizes[i].get())); | |||
| auto task = [&segments, &segment_bucket_sizes, i]() { | |||
| CalculateEachBucketSize<DataType, IndexType>(segments[i], segment_bucket_sizes[i].get()); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| current_offset += data_size; | |||
| } | |||
| for (size_t i = 0; i < params->thread_num_; ++i) { | |||
| threads[i].join(); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| } | |||
| template <typename DataType, typename IndexType> | |||
| @@ -214,18 +216,19 @@ class UniqueCPUKernel : public CPUKernel { | |||
| } | |||
| thread_buckets.emplace_back(local_buckets); | |||
| } | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(thread_num); | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(thread_num); | |||
| current_offset = 0; | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| MS_EXCEPTION_IF_NULL(segments[i]); | |||
| threads.emplace_back( | |||
| std::thread(SegmentToBuckets<DataType, IndexType>, segments[i], current_offset, thread_buckets[i])); | |||
| auto task = [&segments, &thread_buckets, current_offset, i]() { | |||
| SegmentToBuckets<DataType, IndexType>(segments[i], current_offset, thread_buckets[i]); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| current_offset += segments[i]->input_size_; | |||
| } | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads[i].join(); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| @@ -288,14 +291,16 @@ class UniqueCPUKernel : public CPUKernel { | |||
| static void UniqueEachBucket(const std::vector<std::shared_ptr<UniqueParam<DataType, IndexType>>> &buckets) { | |||
| MS_LOG(DEBUG) << "Start"; | |||
| size_t thread_num = buckets.size(); | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(thread_num); | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(thread_num); | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads.emplace_back(std::thread(Unique<DataType, IndexType>, buckets[i])); | |||
| } | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads[i].join(); | |||
| auto task = [&buckets, i]() { | |||
| Unique<DataType, IndexType>(buckets[i]); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| @@ -342,15 +347,16 @@ class UniqueCPUKernel : public CPUKernel { | |||
| } | |||
| result->output_size_ = current_size; | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(thread_num); | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads.emplace_back( | |||
| std::thread(TransformBucketReverseIndices<DataType, IndexType>, buckets[i], result, bucket_offsets[i])); | |||
| } | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(thread_num); | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| threads[i].join(); | |||
| auto task = [&buckets, i, result, &bucket_offsets]() { | |||
| TransformBucketReverseIndices<DataType, IndexType>(buckets[i], result, bucket_offsets[i]); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| @@ -16,10 +16,13 @@ | |||
| #include "common/thread_pool.h" | |||
| #include <algorithm> | |||
| #include <exception> | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "utils/ms_exception.h" | |||
| namespace mindspore { | |||
| namespace common { | |||
| #ifdef ENABLE_D | |||
| const int kDeviceNum = 8; | |||
| #endif | |||
| @@ -52,9 +55,14 @@ bool Queue::Dequeue(Task **out) { | |||
| } | |||
| ThreadPool::ThreadPool() { | |||
| int process_core_num = std::thread::hardware_concurrency() - 1; | |||
| if (process_core_num < 1) { | |||
| process_core_num = 1; | |||
| } | |||
| #ifdef ENABLE_D | |||
| auto cpu_core_num = std::thread::hardware_concurrency(); | |||
| max_thread_num_ = cpu_core_num / kDeviceNum; | |||
| max_thread_num_ = process_core_num / kDeviceNum; | |||
| #else | |||
| max_thread_num_ = process_core_num; | |||
| #endif | |||
| SetThreadPool(core_thread_num_); | |||
| } | |||
| @@ -81,7 +89,13 @@ void ThreadPool::AddNewThread(int add_num) { | |||
| while (!exit_run_) { | |||
| while (*active) { | |||
| if (queue->Dequeue(&task)) { | |||
| auto ret = (*task)(); | |||
| int ret; | |||
| try { | |||
| ret = (*task)(); | |||
| } catch (std::exception &e) { | |||
| ret = FAIL; | |||
| MsException::Instance().SetException(); | |||
| } | |||
| if (ret != SUCCESS) { | |||
| error_info_.emplace_back(std::make_pair(i, std::make_pair(false, ret))); | |||
| } | |||
| @@ -128,7 +142,7 @@ void ThreadPool::SubRunThread(int num) { | |||
| cur_thread_run_nums_ = num; | |||
| } | |||
| bool ThreadPool::LaunchMultipleTask(const std::vector<Task> &tasks) { | |||
| bool ThreadPool::SyncRun(const std::vector<Task> &tasks) { | |||
| int thread_num = tasks.size(); | |||
| if (thread_num > max_thread_num_) { | |||
| thread_num = max_thread_num_; | |||
| @@ -177,14 +191,14 @@ bool ThreadPool::CheckResult() { | |||
| return succ_flag; | |||
| } | |||
| ThreadPool *ThreadPool::GetInstance() { | |||
| ThreadPool &ThreadPool::GetInstance() { | |||
| static ThreadPool instance; | |||
| return &instance; | |||
| return instance; | |||
| } | |||
| ThreadPool::~ThreadPool() { | |||
| cur_thread_run_nums_ = static_cast<int>(thread_list_.size()); | |||
| exit_run_ = true; | |||
| cur_thread_run_nums_ = static_cast<int>(thread_list_.size()); | |||
| SubRunThread(0); | |||
| queue_ready_.notify_all(); | |||
| for (auto &it : thread_list_) { | |||
| @@ -196,4 +210,5 @@ ThreadPool::~ThreadPool() { | |||
| delete it; | |||
| } | |||
| } | |||
| } // namespace common | |||
| } // namespace mindspore | |||
| @@ -31,6 +31,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace common { | |||
| const int kCoreThreadNum = 3; | |||
| const int kDefaultMaxThreadNum = 8; | |||
| enum Status { FAIL = -1, SUCCESS = 0 }; | |||
| @@ -56,9 +57,11 @@ class ThreadPool { | |||
| ThreadPool(const ThreadPool &) = delete; | |||
| ThreadPool &operator=(const ThreadPool &) = delete; | |||
| static ThreadPool *GetInstance(); | |||
| static ThreadPool &GetInstance(); | |||
| // Use the tasks' size of threads to execute these tasks, one thread execute one task. | |||
| bool LaunchMultipleTask(const std::vector<Task> &tasks); | |||
| bool SyncRun(const std::vector<Task> &tasks); | |||
| size_t GetSyncRunThreadNum() { return max_thread_num_; } | |||
| private: | |||
| ThreadPool(); | |||
| @@ -81,6 +84,7 @@ class ThreadPool { | |||
| std::vector<std::shared_ptr<Queue>> queue_list_{}; | |||
| std::vector<std::pair<int, std::pair<bool, int>>> error_info_{}; | |||
| }; | |||
| } // namespace common | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_COMMON_THREAD_POOL_H_ | |||