Browse Source

fix multinomial cpu kernel

tags/v1.5.0-rc1
cristoval 4 years ago
parent
commit
0aa8150fdd
2 changed files with 7 additions and 7 deletions
  1. +7
    -7
      mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_cpu_kernel.cc
  2. +0
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_cpu_kernel.h

mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_gpu_kernel.cc → mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_cpu_kernel.cc View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "backend/kernel_compiler/cpu/multinomial_gpu_kernel.h"
#include "backend/kernel_compiler/cpu/multinomial_cpu_kernel.h"
#include <algorithm>
#include <random>
#include "runtime/device/cpu/cpu_device_address.h"
@@ -80,21 +80,21 @@ bool MultinomialGpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
// Normalize the cumulative array.
float sum = cumulative_value[num_col - 1];
if (sum != 0) {
for (int k = 1; k < num_col; ++k) {
for (int k = 0; k < num_col; ++k) {
cumulative_value[k] /= sum;
}
}

// Initialize random generator.
std::uniform_real_distribution<float> dist(0.0, 1.0);
int RNG_seed = 0;
int64_t RNG_seed = 0;
if (seed2_ > 0) {
RNG_seed = seed2_;
} else if (seed_ > 0) {
RNG_seed = seed_;
} else {
std::random_device rd;
RNG_seed = static_cast<int>(rd());
RNG_seed = static_cast<int64_t>(rd());
}
std::default_random_engine rng{RNG_seed};

@@ -102,12 +102,12 @@ bool MultinomialGpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
for (int n = 0; n < num_sample; ++n) {
auto rand_prob = dist(rng);
int begin = 0;
int end = num_col;
int end = num_col - 1;

while (end - begin > 0) {
int pivot = begin + (end - begin) / 2;
rand_prob = cumulative_value[i * num_col + pivot];
if (pivot > rand_prob) {
float pivot_prob = cumulative_value[i * num_col + pivot];
if (pivot_prob > rand_prob) {
end = pivot;
} else {
begin = pivot + 1;

mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_gpu_kernel.h → mindspore/ccsrc/backend/kernel_compiler/cpu/multinomial_cpu_kernel.h View File


Loading…
Cancel
Save