Browse Source

!15527 Fix unstable CPU TopK output

From: @zhao_ting_v
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @wuxuejian
pull/15527/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
b76c0599d0
3 changed files with 8 additions and 4 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
  2. +2
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc
  3. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.cc

+ 3
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc View File

@@ -267,8 +267,10 @@ bool ArithmeticSelfCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inpu
 const std::vector<kernel::AddressPtr> &outputs) {
 if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) {
 LaunchKernel<float>(inputs, outputs);
-} else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt64) {
+} else if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) {
 LaunchKernel<int>(inputs, outputs);
+} else if (dtype_ == kNumberTypeInt64) {
+LaunchKernel<int64_t>(inputs, outputs);
 } else if (dtype_ == kNumberTypeBool) {
 LaunchKernelLogic<bool>(inputs, outputs);
 } else {


+ 2
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc View File

@@ -29,6 +29,7 @@ void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
 std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index);
 size_t tensor_size =
 shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+tensor_size = std::max(tensor_size, type_size);
 input_size_list_.emplace_back(tensor_size);
 }
 size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
@@ -38,6 +39,7 @@ void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
 std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index);
 size_t tensor_size =
 shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+tensor_size = std::max(tensor_size, type_size);
 output_size_list_.emplace_back(tensor_size);
 }
 }


+ 3
- 3
mindspore/ccsrc/backend/kernel_compiler/cpu/topk_cpu_kernel.cc View File

@@ -50,11 +50,11 @@ void TopKCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const st
 std::vector<size_t> idx(inner_size_);
 auto base_input = i * inner_size_;
 std::iota(idx.begin(), idx.end(), base_input);
-std::sort(idx.begin(), idx.end(),
-[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
+std::stable_sort(idx.begin(), idx.end(),
+[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
 auto base_output = i * k_num;
 if (!sorted_) {
-std::sort(idx.begin(), idx.begin() + k_num);
+std::stable_sort(idx.begin(), idx.begin() + k_num);
 }
 for (int j = 0; j < k_num; ++j) {
 indices[base_output + j] = idx[j] - base_input;


Loading…
Cancel
Save