diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index aa15ac86e8..51db2c8152 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -61,6 +61,13 @@ int ArithmeticCPUKernel::InitBroadCastCase() { return RET_OK; } + if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) && + (arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) { + /* run opt function + * one of input is scalar */ + return RET_OK; + } + FreeTmpPtr(); CalcMultiplesAndStrides(arithmeticParameter_); @@ -216,7 +223,7 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, } bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1 - if (input0_broadcast_ == true || input1_broadcast_ == true) { + if (input0_broadcast_ || input1_broadcast_) { return false; } if (arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->in_elements_num1_ ||