diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
index 24f8a043f5..0da0c36345 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
@@ -21,6 +21,7 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "src/runtime/runtime_api.h"
+#include "src/ops/populate/populate_register.h"
 #include "include/errorcode.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
@@ -97,6 +98,31 @@ int ArithmeticFP16CPUKernel::Init() {
   return ReSize();
 }
 
+// Runs shape inference lazily (when it was deferred at construction time),
+// re-populates the arithmetic parameter from the primitive, resizes the
+// kernel, and allocates output tensor buffers before Run().
+int ArithmeticFP16CPUKernel::PreProcess() {
+  if (!InferShapeDone()) {
+    // primitive_ is stored const; InferShape/SetInferFlag mutate it, hence the const_cast.
+    (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true);
+    auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
+    if (ret != 0) {
+      (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false);
+      MS_LOG(ERROR) << "InferShape fail!";
+      return ret;
+    }
+    // Refresh the op parameter now that shapes are known.
+    param_ = reinterpret_cast<ArithmeticParameter *>(PopulateArithmetic(primitive_));
+    ret = ReSize();
+    if (ret != 0) {
+      MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
+      return ret;
+    }
+  }
+
+  auto outputs = this->out_tensors();
+  for (auto *output : outputs) {
+    MS_ASSERT(output != nullptr);
+    output->MallocData();
+  }
+  return RET_OK;
+}
+
 int ArithmeticFP16CPUKernel::ReSize() {
   param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
   param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
index b1dfcc2236..4eed7e1c14 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
@@ -44,6 +44,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
   ~ArithmeticFP16CPUKernel() = default;
 
   int Init() override;
+  int PreProcess() override;
   int ReSize() override;
   int Run() override;
   int DoArithmetic(int task_id);