diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc index 5d73a25bff..158fd47e02 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc @@ -32,8 +32,7 @@ void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) { } template -bool ConcatCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, +bool ConcatCPUKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { auto node_ = node_wpt_.lock(); if (!node_) { @@ -71,7 +70,7 @@ bool ConcatCPUKernel::Launch(const std::vector &inputs, } template -void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) { +void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) const { size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output."; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h index 0a63dd7ea1..5c30089bbe 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h @@ -34,7 +34,7 @@ class ConcatCPUKernel : public CPUKernel { const std::vector &outputs) override; private: - void CheckParam(const CNodePtr &kernel_node); + void CheckParam(const CNodePtr &kernel_node) const; int axis_ = 0; CNodeWeakPtr node_wpt_; }; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc index 542049126b..bd03ef4fcf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc @@ -26,6 +26,7 @@ void StandardNormal(float *output, std::normal_distribution distribution, output[i] = distribution(random_generator); } } + void LaunchStandardNormal(int seed, int seed2, const std::vector &outputs) { unsigned int RNG_seed; std::random_device rd; @@ -38,33 +39,13 @@ void LaunchStandardNormal(int seed, int seed2, const std::vector &ou } auto output = reinterpret_cast(outputs[0]->addr); - // multithreading size_t lens = outputs[0]->size / sizeof(float); - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - std::vector threads; - threads.reserve(thread_num); - size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } std::normal_distribution distribution; - while (start < lens) { - // avoid different threads using the same seed to generate the same random number + auto task = [&](size_t start, size_t end) { std::default_random_engine random_generator(++RNG_seed); - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - threads.emplace_back(std::thread(StandardNormal, output, distribution, random_generator, start, end)); - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } + StandardNormal(output, distribution, random_generator, start, end); + }; + CPUKernelUtils::ParallelFor(task, lens); } void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) { @@ -91,8 +72,7 @@ void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) { seed2_ = LongToInt(GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"))); } -bool RandomCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, +bool RandomCPUKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { switch (random_op_type_) { case RANDOM_OP_NORMAL: {