Browse Source

!15770 fix codex for master

From: @yuanwei66
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @wuxuejian
pull/15770/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
068bbe5331
3 changed files with 9 additions and 30 deletions
  1. +2
    -3
      mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
  3. +6
    -26
      mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc

+ 2
- 3
mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc View File

@@ -32,8 +32,7 @@ void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
}

template <typename T>
bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto node_ = node_wpt_.lock();
if (!node_) {
@@ -71,7 +70,7 @@ bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
}

template <typename T>
void ConcatCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
void ConcatCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) const {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output.";


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h View File

@@ -34,7 +34,7 @@ class ConcatCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;

private:
void CheckParam(const CNodePtr &kernel_node);
void CheckParam(const CNodePtr &kernel_node) const;
int axis_ = 0;
CNodeWeakPtr node_wpt_;
};


+ 6
- 26
mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc View File

@@ -26,6 +26,7 @@ void StandardNormal(float *output, std::normal_distribution<float> distribution,
output[i] = distribution(random_generator);
}
}

void LaunchStandardNormal(int seed, int seed2, const std::vector<AddressPtr> &outputs) {
unsigned int RNG_seed;
std::random_device rd;
@@ -38,33 +39,13 @@ void LaunchStandardNormal(int seed, int seed2, const std::vector<AddressPtr> &ou
}

auto output = reinterpret_cast<float *>(outputs[0]->addr);
// multithreading
size_t lens = outputs[0]->size / sizeof(float);
auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
if (thread_num < 1) {
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
return;
}
std::vector<std::thread> threads;
threads.reserve(thread_num);
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
if (once_compute_size < 1) {
MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
return;
}
std::normal_distribution<float> distribution;
while (start < lens) {
// avoid different threads using the same seed to generate the same random number
auto task = [&](size_t start, size_t end) {
std::default_random_engine random_generator(++RNG_seed);
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
threads.emplace_back(std::thread(StandardNormal, output, distribution, random_generator, start, end));
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
StandardNormal(output, distribution, random_generator, start, end);
};
CPUKernelUtils::ParallelFor(task, lens);
}

void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -91,8 +72,7 @@ void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) {
seed2_ = LongToInt(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")));
}

bool RandomCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
bool RandomCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
switch (random_op_type_) {
case RANDOM_OP_NORMAL: {


Loading…
Cancel
Save