|
|
|
@@ -96,8 +96,8 @@ void RandomChoiceWithMaskCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
<< "-D, but RandomChoiceWithMask supports only 1-D to 5-D inputs."; |
|
|
|
} |
|
|
|
|
|
|
|
seed_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed")); |
|
|
|
seed2_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed2")); |
|
|
|
seed_ = static_cast<size_t>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed")); |
|
|
|
seed2_ = static_cast<size_t>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "seed2")); |
|
|
|
count_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "count")); |
|
|
|
|
|
|
|
MS_LOG(INFO) << "This op attr count is " << count_; |
|
|
|
@@ -137,7 +137,7 @@ bool RandomChoiceWithMaskCPUKernel::Launch(const std::vector<kernel::AddressPtr> |
|
|
|
auto *output_coordinate = reinterpret_cast<int32_t *>(outputs[0]->addr); |
|
|
|
auto *mask = reinterpret_cast<bool *>(outputs[1]->addr); |
|
|
|
|
|
|
|
int seedc = seed2_ != 0 ? seed2_ : (seed_ != 0 ? seed_ : SizeToInt(generator_())); |
|
|
|
size_t seedc = seed2_ != 0 ? seed2_ : (seed_ != 0 ? seed_ : generator_()); |
|
|
|
for (int32_t i = 0; i < input_total_count; i++) { |
|
|
|
if (input[i] != 0) { |
|
|
|
input_dim[non_zero_num] = i; |
|
|
|
|