diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc index 2dc773a3e6..2eb55a76b0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc @@ -31,7 +31,7 @@ namespace mindspore::kernel { int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) { if (dim > break_pos_) { - if (in_tensors_[0]->data_type() == kNumberTypeInt) { + if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) { return func_int32_(reinterpret_cast(input0) + out_thread_stride, reinterpret_cast(input1) + out_thread_stride, reinterpret_cast(output) + out_thread_stride, out_count); @@ -44,7 +44,7 @@ int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *o int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i; int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i; int error_code; - if (in_tensors_[0]->data_type() == kNumberTypeInt) { + if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) { error_code = BroadcastRun(reinterpret_cast(input0) + pos0_ * param_->in_strides0_[dim], reinterpret_cast(input1) + pos1_ * param_->in_strides1_[dim], reinterpret_cast(output) + i * param_->out_strides_[dim], dim + 1, out_count,