Browse Source

!13366 [MS][LITE][CPU]fix bug of equal

From: @fuzhiye
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @hangangqiang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
ddccb4a454
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc

+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc View File

@@ -31,7 +31,7 @@ namespace mindspore::kernel {
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
int out_thread_stride) { int out_thread_stride) {
if (dim > break_pos_) { if (dim > break_pos_) {
if (in_tensors_[0]->data_type() == kNumberTypeInt) {
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride, return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
reinterpret_cast<int *>(input1) + out_thread_stride, reinterpret_cast<int *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count); reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
@@ -44,7 +44,7 @@ int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *o
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i; int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i; int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i;
int error_code; int error_code;
if (in_tensors_[0]->data_type() == kNumberTypeInt) {
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) {
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim], error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim],
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim], reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count, reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count,


Loading…
Cancel
Save