| @@ -76,15 +76,16 @@ void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size | |||||
| GenIndex(i, &idx); | GenIndex(i, &idx); | ||||
| auto dividend = input1[idx[0]]; | auto dividend = input1[idx[0]]; | ||||
| auto divisor = input2[idx[1]]; | auto divisor = input2[idx[1]]; | ||||
| if (divisor == 0) { | |||||
| if (dividend == 0) { | |||||
| auto zero = (T)0; | |||||
| if (divisor == zero) { | |||||
| if (dividend == zero) { | |||||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | out[i] = std::numeric_limits<T>::quiet_NaN(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (std::numeric_limits<T>::has_infinity) { | if (std::numeric_limits<T>::has_infinity) { | ||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||||
| out[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||||
| } else { | } else { | ||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||||
| out[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -102,15 +103,16 @@ void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t s | |||||
| GenIndex(i, &idx); | GenIndex(i, &idx); | ||||
| auto dividend = input1[idx[0]]; | auto dividend = input1[idx[0]]; | ||||
| auto divisor = input2[idx[1]]; | auto divisor = input2[idx[1]]; | ||||
| if (divisor == 0) { | |||||
| if (dividend == 0) { | |||||
| auto zero = (T)0; | |||||
| if (divisor == zero) { | |||||
| if (dividend == zero) { | |||||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | out[i] = std::numeric_limits<T>::quiet_NaN(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (std::numeric_limits<T>::has_infinity) { | if (std::numeric_limits<T>::has_infinity) { | ||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||||
| out[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||||
| } else { | } else { | ||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||||
| out[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -128,19 +130,20 @@ void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, siz | |||||
| GenIndex(i, &idx); | GenIndex(i, &idx); | ||||
| auto dividend = input1[idx[0]]; | auto dividend = input1[idx[0]]; | ||||
| auto divisor = input2[idx[1]]; | auto divisor = input2[idx[1]]; | ||||
| if (divisor == 0) { | |||||
| if (dividend == 0) { | |||||
| auto zero = (T)0; | |||||
| if (divisor == zero) { | |||||
| if (dividend == zero) { | |||||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | out[i] = std::numeric_limits<T>::quiet_NaN(); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (std::numeric_limits<T>::has_infinity) { | if (std::numeric_limits<T>::has_infinity) { | ||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||||
| out[i] = dividend > zero ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||||
| } else { | } else { | ||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||||
| out[i] = dividend > zero ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| out[i] = floor(dividend / divisor); | |||||
| out[i] = (T)floor(static_cast<double>(dividend) / static_cast<double>(divisor)); | |||||
| } | } | ||||
| }; | }; | ||||
| CPUKernelUtils::ParallelFor(task, size); | CPUKernelUtils::ParallelFor(task, size); | ||||
| @@ -295,7 +298,7 @@ void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t | |||||
| for (size_t i = start; i < end; i++) { | for (size_t i = start; i < end; i++) { | ||||
| std::vector<size_t> idx; | std::vector<size_t> idx; | ||||
| GenIndex(i, &idx); | GenIndex(i, &idx); | ||||
| out[i] = atan2(input1[idx[0]], input2[idx[1]]); | |||||
| out[i] = (T)atan2(static_cast<double>(input1[idx[0]]), static_cast<double>(input2[idx[1]])); | |||||
| } | } | ||||
| }; | }; | ||||
| CPUKernelUtils::ParallelFor(task, size); | CPUKernelUtils::ParallelFor(task, size); | ||||
| @@ -348,8 +351,8 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_); | CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_); | ||||
| CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_); | CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_); | ||||
| CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); | CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); | ||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||||
| if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) { | |||||
| dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||||
| if (dtype_ != AnfAlgo::GetInputDeviceDataType(kernel_node, 1)) { | |||||
| MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; | MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; | ||||
| } | } | ||||
| target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0); | target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0); | ||||
| @@ -358,14 +361,26 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool ArithmeticCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | const std::vector<kernel::AddressPtr> & /*workspace*/, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16 || dtype_ == kNumberTypeInt8) { | |||||
| if (dtype_ == kNumberTypeInt32) { | |||||
| LaunchKernel<int>(inputs, outputs); | LaunchKernel<int>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) { | |||||
| } else if (dtype_ == kNumberTypeFloat32) { | |||||
| LaunchKernel<float>(inputs, outputs); | LaunchKernel<float>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeInt64) { | } else if (dtype_ == kNumberTypeInt64) { | ||||
| LaunchKernel<int64_t>(inputs, outputs); | LaunchKernel<int64_t>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeBool) { | } else if (dtype_ == kNumberTypeBool) { | ||||
| LaunchKernelLogic<bool>(inputs, outputs); | LaunchKernelLogic<bool>(inputs, outputs); | ||||
| } else if (dtype_ == kNumberTypeInt8) { | |||||
| LaunchKernel<int8_t>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeInt16) { | |||||
| LaunchKernel<int16_t>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeFloat16) { | |||||
| LaunchKernel<float16>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeFloat64) { | |||||
| LaunchKernel<double>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeUInt8) { | |||||
| LaunchKernel<uint8_t>(inputs, outputs); | |||||
| } else if (dtype_ == kNumberTypeUInt32) { | |||||
| LaunchKernel<uint32_t>(inputs, outputs); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << "is not support."; | MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << "is not support."; | ||||
| } | } | ||||