Browse Source

fixes RandomChoiceWithMask CPU

tags/v1.6.0
huangbo77 4 years ago
parent
commit
6641c97b6c
2 changed files with 5 additions and 5 deletions
  1. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/cpu/random_choice_with_mask_cpu_kernel.cc
  2. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/random_choice_with_mask_cpu_kernel.h

+ 3
- 3
mindspore/ccsrc/backend/kernel_compiler/cpu/random_choice_with_mask_cpu_kernel.cc View File

@@ -96,8 +96,8 @@ void RandomChoiceWithMaskCPUKernel::InitKernel(const CNodePtr &kernel_node) {
<< "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs.";
}

seed_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed"));
seed2_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed2"));
seed_ = static_cast<size_t>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed"));
seed2_ = static_cast<size_t>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed2"));
count_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "count"));

MS_LOG(INFO) << "This op attr count is " << count_;
@@ -137,7 +137,7 @@ bool RandomChoiceWithMaskCPUKernel::Launch(const std::vector<kernel::AddressPtr>
auto *output_coordinate = reinterpret_cast<int32_t *>(outputs[0]->addr);
auto *mask = reinterpret_cast<bool *>(outputs[1]->addr);

int seedc = seed2_ != 0 ? seed2_ : (seed_ != 0 ? seed_ : SizeToInt(generator_()));
size_t seedc = seed2_ != 0 ? seed2_ : (seed_ != 0 ? seed_ : generator_());
for (int32_t i = 0; i < input_total_count; i++) {
if (input[i] != 0) {
input_dim[non_zero_num] = i;


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/random_choice_with_mask_cpu_kernel.h View File

@@ -49,8 +49,8 @@ class RandomChoiceWithMaskCPUKernel : public CPUKernel {
int32_t count_{0};
std::vector<int32_t> dims_;
size_t input_shape_size_{0};
int seed_{0};
int seed2_{0};
size_t seed_{0};
size_t seed2_{0};
int input_size_{1};
std::mt19937 generator_;
};


Loading…
Cancel
Save