Browse Source

!16045 Fixes RNG_seed bug in StandardNormal operator in r1.2

From: @huangbo77
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @wuxuejian
pull/16045/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
48aa19e8b7
1 changed files with 17 additions and 4 deletions
  1. +17
    -4
      mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc

+ 17
- 4
mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc View File

@@ -15,6 +15,7 @@
*/
#include <random>
#include <thread>
#include "common/thread_pool.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/random_cpu_kernel.h"

@@ -41,11 +42,23 @@ void LaunchStandardNormal(int seed, int seed2, const std::vector<AddressPtr> &ou
auto output = reinterpret_cast<float *>(outputs[0]->addr);
size_t lens = outputs[0]->size / sizeof(float);
std::normal_distribution<float> distribution;
auto task = [&](size_t start, size_t end) {
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0;
size_t thread_num = lens < block_size * max_thread_num ? std::ceil(lens / block_size) : max_thread_num;
std::vector<common::Task> tasks;
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
std::default_random_engine random_generator(++RNG_seed);
StandardNormal(output, distribution, random_generator, start, end);
};
CPUKernelUtils::ParallelFor(task, lens);
auto block = [&, start, end]() {
StandardNormal(output, distribution, random_generator, start, end);
return common::SUCCESS;
};
tasks.emplace_back(block);
start += once_compute_size;
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}

void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) {


Loading…
Cancel
Save