Browse Source

Fix CPU op Equal bug when dtype is int8/uint8, etc.

pull/15222/head
CaoJian 4 years ago
parent
commit
48ff37cd7f
1 changed file with 33 additions and 18 deletions
  1. +33
    -18
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc

+ 33
- 18
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc View File

@@ -76,15 +76,16 @@ void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size
GenIndex(i, &idx); GenIndex(i, &idx);
auto dividend = input1[idx[0]]; auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]]; auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
auto zero = (T)0;
if (divisor == zero) {
if (dividend == zero) {
out[i] = std::numeric_limits<T>::quiet_NaN(); out[i] = std::numeric_limits<T>::quiet_NaN();
continue; continue;
} }
if (std::numeric_limits<T>::has_infinity) { if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
out[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else { } else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
out[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
} }
continue; continue;
} }
@@ -102,15 +103,16 @@ void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t s
GenIndex(i, &idx); GenIndex(i, &idx);
auto dividend = input1[idx[0]]; auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]]; auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
auto zero = (T)0;
if (divisor == zero) {
if (dividend == zero) {
out[i] = std::numeric_limits<T>::quiet_NaN(); out[i] = std::numeric_limits<T>::quiet_NaN();
continue; continue;
} }
if (std::numeric_limits<T>::has_infinity) { if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
out[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else { } else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
out[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
} }
continue; continue;
} }
@@ -128,19 +130,20 @@ void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, siz
GenIndex(i, &idx); GenIndex(i, &idx);
auto dividend = input1[idx[0]]; auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]]; auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
auto zero = (T)0;
if (divisor == zero) {
if (dividend == zero) {
out[i] = std::numeric_limits<T>::quiet_NaN(); out[i] = std::numeric_limits<T>::quiet_NaN();
continue; continue;
} }
if (std::numeric_limits<T>::has_infinity) { if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
out[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else { } else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
out[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
} }
continue; continue;
} }
out[i] = floor(dividend / divisor);
out[i] = (T)floor(static_cast<double>(dividend) / static_cast<double>(divisor));
} }
}; };
CPUKernelUtils::ParallelFor(task, size); CPUKernelUtils::ParallelFor(task, size);
@@ -295,7 +298,7 @@ void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t
for (size_t i = start; i < end; i++) { for (size_t i = start; i < end; i++) {
std::vector<size_t> idx; std::vector<size_t> idx;
GenIndex(i, &idx); GenIndex(i, &idx);
out[i] = atan2(input1[idx[0]], input2[idx[1]]);
out[i] = (T)atan2(static_cast<double>(input1[idx[0]]), static_cast<double>(input2[idx[1]]));
} }
}; };
CPUKernelUtils::ParallelFor(task, size); CPUKernelUtils::ParallelFor(task, size);
@@ -348,8 +351,8 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_); CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_);
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_); CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) {
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (dtype_ != AnfAlgo::GetInputDeviceDataType(kernel_node, 1)) {
MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type";
} }
target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0); target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
@@ -358,14 +361,26 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/, const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt8) {
if (dtype_ == kNumberTypeInt32) {
LaunchKernel<int>(inputs, outputs); LaunchKernel<int>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) {
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs); LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt64) { } else if (dtype_ == kNumberTypeInt64) {
LaunchKernel<int64_t>(inputs, outputs); LaunchKernel<int64_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeBool) { } else if (dtype_ == kNumberTypeBool) {
LaunchKernelLogic<bool>(inputs, outputs); LaunchKernelLogic<bool>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt8) {
LaunchKernel<int8_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt16) {
LaunchKernel<int16_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, outputs);
} else if (dtype_ == kNumberTypeUInt8) {
LaunchKernel<uint8_t>(inputs, outputs);
} else if (dtype_ == kNumberTypeUInt32) {
LaunchKernel<uint32_t>(inputs, outputs);
} else { } else {
MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << "is not support."; MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << "is not support.";
} }


Loading…
Cancel
Save