Browse Source

fix cpu multinomial

tags/v1.6.0
VectorSL 4 years ago
parent
commit
21ea25b0aa
2 changed files with 13 additions and 12 deletions
  1. +11
    -12
      mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_cpu_kernel.cc
  2. +2
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_cpu_kernel.h

+ 11
- 12
mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_cpu_kernel.cc View File

@@ -16,7 +16,6 @@

#include "backend/kernel_compiler/cpu/multinomial_cpu_kernel.h"
#include <algorithm>
#include <random>
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
@@ -35,6 +34,16 @@ void MultinomialCpuKernel::InitKernel(const CNodePtr &kernel_node) {

seed_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")));
seed2_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")));
int64_t RNG_seed = 0;
if (seed2_ > 0) {
RNG_seed = seed2_;
} else if (seed_ > 0) {
RNG_seed = seed_;
} else {
std::random_device rd;
RNG_seed = static_cast<int64_t>(rd());
}
rng_.seed(RNG_seed);
}

bool MultinomialCpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
@@ -91,20 +100,10 @@ bool MultinomialCpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,

// Initialize random generator.
std::uniform_real_distribution<float> dist(0.0, 1.0);
int64_t RNG_seed = 0;
if (seed2_ > 0) {
RNG_seed = seed2_;
} else if (seed_ > 0) {
RNG_seed = seed_;
} else {
std::random_device rd;
RNG_seed = static_cast<int64_t>(rd());
}
std::default_random_engine rng{RNG_seed};

// Sample data from cumulative array.
for (int n = 0; n < num_sample; ++n) {
auto rand_prob = dist(rng);
auto rand_prob = dist(rng_);
int begin = 0;
int end = num_col - 1;



+ 2
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_cpu_kernel.h View File

@@ -19,6 +19,7 @@
#include <memory>
#include <unordered_map>
#include <vector>
#include <random>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "nnacl/base/tile_base.h"
@@ -39,6 +40,7 @@ class MultinomialCpuKernel : public CPUKernel {
std::vector<size_t> input_shape_;
int seed_{0};
int seed2_{0};
std::default_random_engine rng_;
};

MS_REG_CPU_KERNEL(


Loading…
Cancel
Save