|
|
|
@@ -21,6 +21,7 @@ |
|
|
|
#include "schema/model_generated.h" |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
#include "src/runtime/runtime_api.h" |
|
|
|
#include "src/ops/populate/populate_register.h" |
|
|
|
#include "include/errorcode.h" |
|
|
|
|
|
|
|
using mindspore::kernel::KERNEL_ARCH::kCPU; |
|
|
|
@@ -97,6 +98,31 @@ int ArithmeticFP16CPUKernel::Init() { |
|
|
|
return ReSize(); |
|
|
|
} |
|
|
|
|
|
|
|
int ArithmeticFP16CPUKernel::PreProcess() { |
|
|
|
if (!InferShapeDone()) { |
|
|
|
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true); |
|
|
|
auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_); |
|
|
|
if (ret != 0) { |
|
|
|
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false); |
|
|
|
MS_LOG(ERROR) << "InferShape fail!"; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
param_ = reinterpret_cast<ArithmeticParameter *>(PopulateArithmetic(primitive_)); |
|
|
|
ret = ReSize(); |
|
|
|
if (ret != 0) { |
|
|
|
MS_LOG(ERROR) << "ReSize fail!ret: " << ret; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
auto outputs = this->out_tensors(); |
|
|
|
for (auto *output : outputs) { |
|
|
|
MS_ASSERT(output != nullptr); |
|
|
|
output->MallocData(); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int ArithmeticFP16CPUKernel::ReSize() { |
|
|
|
param_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); |
|
|
|
param_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); |
|
|
|
|