Browse Source

!10005 add sync run for thread pool

From: @kisnwang
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
27b337a4d2
8 changed files with 128 additions and 116 deletions
  1. +10
    -12
      mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.cc
  2. +11
    -13
      mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
  3. +5
    -9
      mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc
  4. +40
    -40
      mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h
  5. +1
    -6
      mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.cc
  6. +33
    -27
      mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.h
  7. +22
    -7
      mindspore/ccsrc/common/thread_pool.cc
  8. +6
    -2
      mindspore/ccsrc/common/thread_pool.h

+ 10
- 12
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_delta_cpu_kernel.cc View File

@@ -20,15 +20,11 @@
#include <memory>
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"

namespace mindspore {
namespace kernel {
constexpr size_t kAdamDeltaInputSize = 9;
#ifdef ENABLE_D
constexpr size_t kUsedThreadNum = 23;
#else
constexpr size_t kUsedThreadNum = 8;
#endif
namespace {
struct ComputeParam {
float *delta_{nullptr};
@@ -139,13 +135,13 @@ bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto grad = reinterpret_cast<float *>(inputs[8]->addr);
auto delta = reinterpret_cast<float *>(outputs[0]->addr);
lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
size_t thread_num = kUsedThreadNum;
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
if (elem_num_ < thread_num) {
thread_num = elem_num_;
}
std::vector<std::thread> threads;
std::vector<common::Task> tasks;
std::vector<std::shared_ptr<ComputeParam>> thread_params;
threads.reserve(thread_num);
tasks.reserve(thread_num);

size_t end = 0;
size_t offset = elem_num_ / thread_num;
@@ -166,12 +162,14 @@ bool AdamDeltaCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
if (i < left) {
end += 1;
}
threads.emplace_back(std::thread(ComputeWeightDelta, params, start, end));
auto task = [&params, start, end]() {
ComputeWeightDelta(params, start, end);
return common::SUCCESS;
};
tasks.emplace_back(task);
thread_params.emplace_back(params);
}
for (size_t i = 0; i < thread_num; ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
return true;
}
} // namespace kernel


+ 11
- 13
mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc View File

@@ -18,15 +18,11 @@
#include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "ir/primitive.h"
#include "common/thread_pool.h"

namespace mindspore {
namespace kernel {
namespace {
#ifdef ENABLE_D
constexpr size_t kUsedThreadNum = 23;
#else
constexpr size_t kUsedThreadNum = 8;
#endif
template <typename T>
void LookUpTableTask(const float *input_addr, const T *indices_addr, float *output_addr, size_t indices_lens,
size_t outer_dim_size, T offset, size_t first_dim_size) {
@@ -98,8 +94,9 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr
auto indices_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
size_t thread_num = indices_lens_ / 10000 + 1;
thread_num = thread_num > kUsedThreadNum ? kUsedThreadNum : thread_num;
std::thread threads[kUsedThreadNum];
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
thread_num = thread_num > max_thread_num ? max_thread_num : thread_num;
std::vector<common::Task> tasks;
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
size_t i;
size_t task_offset = 0;
@@ -109,17 +106,18 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr
break;
}
MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens;
threads[i] = std::thread(LookUpTableTask<T>, input_addr, indices_addr + task_offset,
output_addr + task_offset * outer_dim_size_, task_proc_lens, outer_dim_size_, offset_,
first_dim_size_);
auto task = [input_addr, indices_addr, output_addr, task_offset, task_proc_lens, this]() {
LookUpTableTask<T>(input_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size_,
task_proc_lens, outer_dim_size_, offset_, first_dim_size_);
return common::SUCCESS;
};
tasks.emplace_back(task);
task_offset += task_proc_lens;
if (task_offset + task_proc_lens > indices_lens_) {
task_proc_lens = indices_lens_ - task_offset;
}
}
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}

bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,


+ 5
- 9
mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc View File

@@ -22,11 +22,6 @@
namespace mindspore {
namespace kernel {
namespace {
#ifdef ENABLE_D
constexpr size_t kUsedThreadNum = 23;
#else
constexpr size_t kUsedThreadNum = 8;
#endif
template <typename T>
void Compute(const ComputeParams<T> *params, const size_t start, const size_t end) {
MS_EXCEPTION_IF_NULL(params);
@@ -120,19 +115,20 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input
params.indices_unit_rank_ = indices_unit_rank_;
params.out_strides_ = &out_strides_;

std::vector<Task> tasks;
std::vector<common::Task> tasks;
size_t start = 0;
size_t once_compute_size = (num_units_ + kUsedThreadNum - 1) / kUsedThreadNum;
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
size_t once_compute_size = (num_units_ + max_thread_num - 1) / max_thread_num;
while (start < num_units_) {
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
auto task = [&params, start, end]() -> int {
Compute<T>(&params, start, end);
return SUCCESS;
return common::SUCCESS;
};
tasks.emplace_back(task);
start += once_compute_size;
}
ThreadPool::GetInstance()->LaunchMultipleTask(tasks);
common::ThreadPool::GetInstance().SyncRun(tasks);

auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
if (ret != 0) {


+ 40
- 40
mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_optimizer_cpu_kernel.h View File

@@ -18,20 +18,14 @@

#include <vector>
#include <memory>
#include <thread>
#include <unordered_map>
#include <algorithm>
#include <utility>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
#ifdef ENABLE_D
constexpr size_t kUsedThreadNum = 23;
#else
constexpr size_t kUsedThreadNum = 8;
#endif
template <typename T>
struct SparseGradient {
float *value_{nullptr};
@@ -100,7 +94,7 @@ class SparseOptimizerCPUKernel : public CPUKernel {
static void BucketReduceSparseGradient(const ReduceSparseGradientParam<T> &param) {
MS_LOG(DEBUG) << "Start";
MS_EXCEPTION_IF_NULL(param.input_grad_);
size_t thread_num = kUsedThreadNum;
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
if (param.input_grad_->indices_size_ < thread_num) {
thread_num = param.input_grad_->indices_size_;
}
@@ -125,18 +119,21 @@ class SparseOptimizerCPUKernel : public CPUKernel {
template <typename T>
void MultiThreadCompute(const MultiThreadComputeFunc<T> &func, MultiThreadComputeParams<T> *params,
size_t total_compute_size) const {
std::vector<std::thread> threads;
threads.reserve(kUsedThreadNum);
std::vector<common::Task> tasks;
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
tasks.reserve(max_thread_num);
size_t start = 0;
size_t once_compute_size = (total_compute_size + kUsedThreadNum - 1) / kUsedThreadNum;
size_t once_compute_size = (total_compute_size + max_thread_num - 1) / max_thread_num;
while (start < total_compute_size) {
size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
threads.emplace_back(std::thread(func, params, start, end));
auto task = [&func, &params, start, end]() {
func(params, start, end);
return common::SUCCESS;
};
tasks.emplace_back(task);
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}

private:
@@ -173,8 +170,8 @@ class SparseOptimizerCPUKernel : public CPUKernel {
}
size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_;
size_t left_indices_size = input_grad->indices_size_ % param.thread_num_;
std::vector<std::thread> threads;
threads.reserve(param.thread_num_);
std::vector<common::Task> tasks;
tasks.reserve(param.thread_num_);
segments.reserve(param.thread_num_);

size_t current_indices_offset = 0;
@@ -188,14 +185,14 @@ class SparseOptimizerCPUKernel : public CPUKernel {
segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_;
segments[i]->indices_ = input_grad->indices_ + current_indices_offset;
segments[i]->indices_size_ = indices_size;
threads.emplace_back(
std::thread(CalculateEachBucketSize<T>, segments[i], param.max_index_, segment_bucket_sizes[i].get()));
auto task = [&segments, &param, &segment_bucket_sizes, i]() {
CalculateEachBucketSize<T>(segments[i], param.max_index_, segment_bucket_sizes[i].get());
return common::SUCCESS;
};
tasks.emplace_back(task);
current_indices_offset += indices_size;
}

for (size_t i = 0; i < param.thread_num_; ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}

template <typename T>
@@ -263,17 +260,18 @@ class SparseOptimizerCPUKernel : public CPUKernel {
}
each_thread_buckets.emplace_back(thread_buckets);
}
std::vector<std::thread> threads;
threads.reserve(thread_num);
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
current_indices_offset = 0;
for (size_t i = 0; i < thread_num; ++i) {
threads.emplace_back(
std::thread(CopySegmentIndicesToBucket<T>, param, segments[i], current_indices_offset, each_thread_buckets[i]));
auto task = [&param, &segments, &each_thread_buckets, i, current_indices_offset]() {
CopySegmentIndicesToBucket<T>(param, segments[i], current_indices_offset, each_thread_buckets[i]);
return common::SUCCESS;
};
tasks.emplace_back(task);
current_indices_offset += segments[i]->indices_size_;
}
for (size_t i = 0; i < thread_num; ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}

template <typename T>
@@ -381,8 +379,8 @@ class SparseOptimizerCPUKernel : public CPUKernel {
MS_EXCEPTION_IF_NULL(reduced_buckets_ptr);
auto &reduced_buckets = *reduced_buckets_ptr;
size_t thread_num = buckets.size();
std::vector<std::thread> threads;
threads.reserve(thread_num);
std::vector<common::Task> tasks;
tasks.reserve(thread_num);

size_t current_indices_offset = 0;
for (size_t i = 0; i < thread_num; ++i) {
@@ -390,16 +388,18 @@ class SparseOptimizerCPUKernel : public CPUKernel {
reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_;
reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset;
reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_;
if (param.use_sort_reduce_) {
threads.emplace_back(std::thread(SortAndReduceBucketSparseGradient<T>, param, buckets[i], reduced_buckets[i]));
} else {
threads.emplace_back(std::thread(ReduceBucketSparseGradient<T>, param, buckets[i], reduced_buckets[i]));
}
auto task = [&param, &buckets, &reduced_buckets, i]() {
if (param.use_sort_reduce_) {
SortAndReduceBucketSparseGradient<T>(param, buckets[i], reduced_buckets[i]);
} else {
ReduceBucketSparseGradient<T>(param, buckets[i], reduced_buckets[i]);
}
return common::SUCCESS;
};
tasks.emplace_back(task);
current_indices_offset += buckets[i]->indices_size_;
}
for (size_t i = 0; i < thread_num; ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}

template <typename T>


+ 1
- 6
mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.cc View File

@@ -20,11 +20,6 @@
namespace mindspore {
namespace kernel {
const size_t kUseBucketUniqueSize = 100000;
#ifdef ENABLE_D
constexpr size_t kUniqueThreadNum = 23;
#else
constexpr size_t kUniqueThreadNum = 8;
#endif
void UniqueCPUKernel::InitKernel(const CNodePtr &kernel_node) {
node_ = kernel_node;
CheckParam(kernel_node);
@@ -88,7 +83,7 @@ void UniqueCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const
params->input_size_ = input_size_;
params->output_size_ = 0;
params->need_sort_ = true;
params->thread_num_ = kUniqueThreadNum;
params->thread_num_ = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
if (input_size_ < kUseBucketUniqueSize) {
Unique(params);
} else {


+ 33
- 27
mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.h View File

@@ -23,6 +23,7 @@
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "common/thread_pool.h"

namespace mindspore {
namespace kernel {
@@ -104,11 +105,11 @@ class UniqueCPUKernel : public CPUKernel {
}
IndexType thread_data_size = input_size / thread_num;
size_t left_data_size = input_size % thread_num;
std::vector<std::thread> threads;
threads.reserve(thread_num);
segments.reserve(thread_num);
segment_bucket_sizes.reserve(thread_num);
IndexType current_offset = 0;
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
for (size_t i = 0; i < thread_num; ++i) {
segment_bucket_sizes.emplace_back(std::make_shared<std::vector<IndexType>>(thread_num, 0));
IndexType data_size = thread_data_size;
@@ -119,13 +120,14 @@ class UniqueCPUKernel : public CPUKernel {
segments[i]->input_ = params->input_ + current_offset;
segments[i]->input_size_ = data_size;
segments[i]->thread_num_ = thread_num;
threads.emplace_back(
std::thread(CalculateEachBucketSize<DataType, IndexType>, segments[i], segment_bucket_sizes[i].get()));
auto task = [&segments, &segment_bucket_sizes, i]() {
CalculateEachBucketSize<DataType, IndexType>(segments[i], segment_bucket_sizes[i].get());
return common::SUCCESS;
};
tasks.emplace_back(task);
current_offset += data_size;
}
for (size_t i = 0; i < params->thread_num_; ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}

template <typename DataType, typename IndexType>
@@ -214,18 +216,19 @@ class UniqueCPUKernel : public CPUKernel {
}
thread_buckets.emplace_back(local_buckets);
}
std::vector<std::thread> threads;
threads.reserve(thread_num);
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
current_offset = 0;
for (size_t i = 0; i < thread_num; ++i) {
MS_EXCEPTION_IF_NULL(segments[i]);
threads.emplace_back(
std::thread(SegmentToBuckets<DataType, IndexType>, segments[i], current_offset, thread_buckets[i]));
auto task = [&segments, &thread_buckets, current_offset, i]() {
SegmentToBuckets<DataType, IndexType>(segments[i], current_offset, thread_buckets[i]);
return common::SUCCESS;
};
tasks.emplace_back(task);
current_offset += segments[i]->input_size_;
}
for (size_t i = 0; i < thread_num; ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
MS_LOG(DEBUG) << "End";
}

@@ -288,14 +291,16 @@ class UniqueCPUKernel : public CPUKernel {
static void UniqueEachBucket(const std::vector<std::shared_ptr<UniqueParam<DataType, IndexType>>> &buckets) {
MS_LOG(DEBUG) << "Start";
size_t thread_num = buckets.size();
std::vector<std::thread> threads;
threads.reserve(thread_num);
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
for (size_t i = 0; i < thread_num; ++i) {
threads.emplace_back(std::thread(Unique<DataType, IndexType>, buckets[i]));
}
for (size_t i = 0; i < thread_num; ++i) {
threads[i].join();
auto task = [&buckets, i]() {
Unique<DataType, IndexType>(buckets[i]);
return common::SUCCESS;
};
tasks.emplace_back(task);
}
common::ThreadPool::GetInstance().SyncRun(tasks);
MS_LOG(DEBUG) << "End";
}

@@ -342,15 +347,16 @@ class UniqueCPUKernel : public CPUKernel {
}
result->output_size_ = current_size;

std::vector<std::thread> threads;
threads.reserve(thread_num);
for (size_t i = 0; i < thread_num; ++i) {
threads.emplace_back(
std::thread(TransformBucketReverseIndices<DataType, IndexType>, buckets[i], result, bucket_offsets[i]));
}
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
for (size_t i = 0; i < thread_num; ++i) {
threads[i].join();
auto task = [&buckets, i, result, &bucket_offsets]() {
TransformBucketReverseIndices<DataType, IndexType>(buckets[i], result, bucket_offsets[i]);
return common::SUCCESS;
};
tasks.emplace_back(task);
}
common::ThreadPool::GetInstance().SyncRun(tasks);
MS_LOG(DEBUG) << "End";
}



+ 22
- 7
mindspore/ccsrc/common/thread_pool.cc View File

@@ -16,10 +16,13 @@

#include "common/thread_pool.h"
#include <algorithm>
#include <exception>
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_exception.h"

namespace mindspore {
namespace common {
#ifdef ENABLE_D
const int kDeviceNum = 8;
#endif
@@ -52,9 +55,14 @@ bool Queue::Dequeue(Task **out) {
}

ThreadPool::ThreadPool() {
int process_core_num = std::thread::hardware_concurrency() - 1;
if (process_core_num < 1) {
process_core_num = 1;
}
#ifdef ENABLE_D
auto cpu_core_num = std::thread::hardware_concurrency();
max_thread_num_ = cpu_core_num / kDeviceNum;
max_thread_num_ = process_core_num / kDeviceNum;
#else
max_thread_num_ = process_core_num;
#endif
SetThreadPool(core_thread_num_);
}
@@ -81,7 +89,13 @@ void ThreadPool::AddNewThread(int add_num) {
while (!exit_run_) {
while (*active) {
if (queue->Dequeue(&task)) {
auto ret = (*task)();
int ret;
try {
ret = (*task)();
} catch (std::exception &e) {
ret = FAIL;
MsException::Instance().SetException();
}
if (ret != SUCCESS) {
error_info_.emplace_back(std::make_pair(i, std::make_pair(false, ret)));
}
@@ -128,7 +142,7 @@ void ThreadPool::SubRunThread(int num) {
cur_thread_run_nums_ = num;
}

bool ThreadPool::LaunchMultipleTask(const std::vector<Task> &tasks) {
bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
int thread_num = tasks.size();
if (thread_num > max_thread_num_) {
thread_num = max_thread_num_;
@@ -177,14 +191,14 @@ bool ThreadPool::CheckResult() {
return succ_flag;
}

ThreadPool *ThreadPool::GetInstance() {
ThreadPool &ThreadPool::GetInstance() {
static ThreadPool instance;
return &instance;
return instance;
}

ThreadPool::~ThreadPool() {
cur_thread_run_nums_ = static_cast<int>(thread_list_.size());
exit_run_ = true;
cur_thread_run_nums_ = static_cast<int>(thread_list_.size());
SubRunThread(0);
queue_ready_.notify_all();
for (auto &it : thread_list_) {
@@ -196,4 +210,5 @@ ThreadPool::~ThreadPool() {
delete it;
}
}
} // namespace common
} // namespace mindspore

+ 6
- 2
mindspore/ccsrc/common/thread_pool.h View File

@@ -31,6 +31,7 @@
#include "utils/log_adapter.h"

namespace mindspore {
namespace common {
const int kCoreThreadNum = 3;
const int kDefaultMaxThreadNum = 8;
enum Status { FAIL = -1, SUCCESS = 0 };
@@ -56,9 +57,11 @@ class ThreadPool {
ThreadPool(const ThreadPool &) = delete;
ThreadPool &operator=(const ThreadPool &) = delete;

static ThreadPool *GetInstance();
static ThreadPool &GetInstance();
// Use the tasks' size of threads to execute these tasks, one thread execute one task.
bool LaunchMultipleTask(const std::vector<Task> &tasks);
bool SyncRun(const std::vector<Task> &tasks);

size_t GetSyncRunThreadNum() { return max_thread_num_; }

private:
ThreadPool();
@@ -81,6 +84,7 @@ class ThreadPool {
std::vector<std::shared_ptr<Queue>> queue_list_{};
std::vector<std::pair<int, std::pair<bool, int>>> error_info_{};
};
} // namespace common
} // namespace mindspore

#endif // MINDSPORE_CCSRC_COMMON_THREAD_POOL_H_

Loading…
Cancel
Save