From 48ff37cd7f7c6c2366218e8cb24fbd3bcc2f4b31 Mon Sep 17 00:00:00 2001 From: CaoJian Date: Thu, 15 Apr 2021 15:59:47 +0800 Subject: [PATCH] fix cpu op Equal bug when dtype is int8/unit8 etc. --- .../cpu/arithmetic_cpu_kernel.cc | 51 ++++++++++++------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc index 262fa0024f..6a5cd5bdf5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc @@ -76,15 +76,16 @@ void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size GenIndex(i, &idx); auto dividend = input1[idx[0]]; auto divisor = input2[idx[1]]; - if (divisor == 0) { - if (dividend == 0) { + auto zero = (T)0; + if (divisor == zero) { + if (dividend == zero) { out[i] = std::numeric_limits::quiet_NaN(); continue; } if (std::numeric_limits::has_infinity) { - out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); + out[i] = dividend > zero ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); } else { - out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); + out[i] = dividend > zero ? std::numeric_limits::max() : std::numeric_limits::min(); } continue; } @@ -102,15 +103,16 @@ void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t s GenIndex(i, &idx); auto dividend = input1[idx[0]]; auto divisor = input2[idx[1]]; - if (divisor == 0) { - if (dividend == 0) { + auto zero = (T)0; + if (divisor == zero) { + if (dividend == zero) { out[i] = std::numeric_limits::quiet_NaN(); continue; } if (std::numeric_limits::has_infinity) { - out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); + out[i] = dividend > zero ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); } else { - out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); + out[i] = dividend > zero ? std::numeric_limits::max() : std::numeric_limits::min(); } continue; } @@ -128,19 +130,20 @@ void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, siz GenIndex(i, &idx); auto dividend = input1[idx[0]]; auto divisor = input2[idx[1]]; - if (divisor == 0) { - if (dividend == 0) { + auto zero = (T)0; + if (divisor == zero) { + if (dividend == zero) { out[i] = std::numeric_limits::quiet_NaN(); continue; } if (std::numeric_limits::has_infinity) { - out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); + out[i] = dividend > zero ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); } else { - out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); + out[i] = dividend > zero ? std::numeric_limits::max() : std::numeric_limits::min(); } continue; } - out[i] = floor(dividend / divisor); + out[i] = (T)floor(static_cast(dividend) / static_cast(divisor)); } }; CPUKernelUtils::ParallelFor(task, size); @@ -295,7 +298,7 @@ void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t for (size_t i = start; i < end; i++) { std::vector idx; GenIndex(i, &idx); - out[i] = atan2(input1[idx[0]], input2[idx[1]]); + out[i] = (T)atan2(static_cast(input1[idx[0]]), static_cast(input2[idx[1]])); } }; CPUKernelUtils::ParallelFor(task, size); @@ -348,8 +351,8 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_); CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_); CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); - if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) { + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + if (dtype_ != AnfAlgo::GetInputDeviceDataType(kernel_node, 1)) { MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; } target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0); @@ -358,14 +361,26 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { bool ArithmeticCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, const std::vector &outputs) { - if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt8) { + if (dtype_ == kNumberTypeInt32) { LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) { + } else if (dtype_ == kNumberTypeFloat32) { LaunchKernel(inputs, outputs); } else if (dtype_ == kNumberTypeInt64) { LaunchKernel(inputs, outputs); } else if (dtype_ == kNumberTypeBool) { LaunchKernelLogic(inputs, outputs); + } else if (dtype_ == kNumberTypeInt8) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat64) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeUInt8) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeUInt32) { + LaunchKernel(inputs, outputs); } else { MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << "is not support."; }