diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc index bd03ef4fcf..19021a6ae4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc @@ -15,6 +15,7 @@ */ #include #include +#include "common/thread_pool.h" #include "runtime/device/cpu/cpu_device_address.h" #include "backend/kernel_compiler/cpu/random_cpu_kernel.h" @@ -41,11 +42,23 @@ void LaunchStandardNormal(int seed, int seed2, const std::vector &ou auto output = reinterpret_cast(outputs[0]->addr); size_t lens = outputs[0]->size / sizeof(float); std::normal_distribution distribution; - auto task = [&](size_t start, size_t end) { + auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); + const float block_size = 128.0; + size_t thread_num = lens < block_size * max_thread_num ? std::ceil(lens / block_size) : max_thread_num; + std::vector tasks; + size_t start = 0; + size_t once_compute_size = (lens + thread_num - 1) / thread_num; + while (start < lens) { + size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); std::default_random_engine random_generator(++RNG_seed); - StandardNormal(output, distribution, random_generator, start, end); - }; - CPUKernelUtils::ParallelFor(task, lens); + auto block = [&, start, end]() { + StandardNormal(output, distribution, random_generator, start, end); + return common::SUCCESS; + }; + tasks.emplace_back(block); + start += once_compute_size; + } + common::ThreadPool::GetInstance().SyncRun(tasks); } void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) {