@@ -22,6 +22,7 @@
 #include "src/runtime/kernel/arm/int8/add_int8.h"
 #include "src/runtime/kernel/arm/int8/mul_int8.h"
 #include "src/runtime/runtime_api.h"
+#include "src/populate_parameter.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -40,6 +41,31 @@ int ArithmeticCPUKernel::Init() {
   return ReSize();
 }
 
+int ArithmeticCPUKernel::PreProcess() {
+  if (!InferShapeDone()) {
+    (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(true);
+    auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
+    if (ret != 0) {
+      (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->SetInferFlag(false);
+      MS_LOG(ERROR) << "InferShape fail!";
+      return ret;
+    }
+    arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(kernel::PopulateArithmetic(primitive_));
+    ret = ReSize();
+    if (ret != 0) {
+      MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
+      return ret;
+    }
+  }
+
+  auto outputs = this->out_tensors();
+  for (auto *output : outputs) {
+    MS_ASSERT(output != nullptr);
+    output->MallocData();
+  }
+  return RET_OK;
+}
+
 int ArithmeticCPUKernel::ReSize() {
   if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16) {
     data_type_ = kDataTypeFloat;
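
Note: the new PreProcess() (declared as an override in the header hunk below) covers models whose shapes are unknown at graph-build time: it re-runs InferShape() at execution time, repopulates the ArithmeticParameter, calls ReSize(), and only then allocates the output buffers. A minimal, self-contained sketch of this deferred-inference pattern, using toy types rather than the MindSpore Lite API:

    #include <iostream>
    #include <vector>

    // Toy stand-ins for InferShape()/ReSize()/MallocData(); illustrative only.
    struct ToyKernel {
      bool shape_known = false;     // plays the role of InferShapeDone()
      std::vector<float> output;

      int InferShape() { shape_known = true; return 0; }  // would compute output shapes
      int ReSize() { return 0; }                          // would rebuild per-shape state

      // Mirrors the patch: infer late, resize, then allocate outputs before Run().
      int PreProcess() {
        if (!shape_known) {
          if (InferShape() != 0) return -1;  // mirrors the SetInferFlag error path
          if (ReSize() != 0) return -1;
        }
        output.resize(8);  // mirrors output->MallocData()
        return 0;
      }
    };

    int main() {
      ToyKernel k;
      if (k.PreProcess() == 0) {
        std::cout << "output elements: " << k.output.size() << '\n';
      }
      return 0;
    }
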
@@ -163,6 +163,7 @@ class ArithmeticCPUKernel : public LiteKernel {
   ~ArithmeticCPUKernel() override;
 
   int Init() override;
+  int PreProcess() override;
   int ReSize() override;
   int Run() override;
   int DoArithmetic(int task_id);
@@ -81,7 +81,7 @@ int QuantizedAddCPUKernel::Run() {
   input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData());
   output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData());
 
-  elements_num_ = in_tensors_.at(0)->ElementsNum();
+  elements_num_ = out_tensors_.at(0)->ElementsNum();
   count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
 
   if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
@@ -106,7 +106,7 @@ int MulInt8CPUKernel::Run() {
   input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData());
   output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData());
-  elements_num_ = in_tensors_.at(0)->ElementsNum();
+  elements_num_ = out_tensors_.at(0)->ElementsNum();
   count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
 
   if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
     input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
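
Note: both int8 Run() hunks fix the same bug: elements_num_ was taken from input 0, which undercounts the work when input 0 is the smaller operand of a broadcast; the output tensor always carries the true element count. A hedged arithmetic illustration of the difference (assumed shapes; UP_DIV as used above):

    #include <cstdio>

    // The same round-up division the kernels use to split work across threads.
    #define UP_DIV(x, y) (((x) + (y)-1) / (y))

    int main() {
      const int in0_elems = 1;  // scalar operand that gets broadcast
      const int out_elems = 8;  // broadcast result, shape (2, 4)
      const int threads = 4;

      // Old sizing from input 0: elements_num_ = 1, so only 1 of the 8
      // output elements would ever be computed.
      printf("count_unit (old):   %d\n", UP_DIV(in0_elems, threads));

      // Fixed sizing from the output: 2 elements per thread, all 8 covered.
      printf("count_unit (fixed): %d\n", UP_DIV(out_elems, threads));
      return 0;
    }
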
@@ -87,8 +87,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
         MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
         return RET_ERROR;
       }
-      attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(0));
-      attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(1));
+      attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
+      attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
     } else if (onnx_node_attr.name() == "kernels") {
       if (onnx_node_attr.ints().size() != 2) {
         MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
@@ -101,8 +101,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
         MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
         return RET_ERROR;
       }
-      attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(0));
-      attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(1));
+      attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
+      attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
     } else if (onnx_node_attr.name() == "auto_pad") {
       attr->padMode = GetOnnxPadMode(onnx_node_attr);
     } else if (onnx_node_attr.name() == "pads") {
@@ -119,8 +119,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
         MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
         return RET_ERROR;
       }
-      attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(0));
-      attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(1));
+      attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
+      attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
     } else if (onnx_node_attr.name() == "order") {
       if (onnx_node_attr.s() == "NHWC") {
         attr->format = schema::Format::Format_NHWC;
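
Note: for 2-D ONNX operators, spatial attributes such as dilations, kernel_shape, and strides are listed in axis order [H, W], so ints(0) is the height component. The old code wrote it into the width field, silently transposing every non-square value (square kernels and unit strides masked the bug). A small sketch of the corrected mapping, with a toy attr struct standing in for the converter's schema:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct ToyDeConvAttr { int32_t strideH = 0; int32_t strideW = 0; };

    int main() {
      // An ONNX ConvTranspose with strides = [2, 1] moves 2 along H and 1 along W.
      const std::vector<int64_t> onnx_strides = {2, 1};

      ToyDeConvAttr attr;
      attr.strideH = static_cast<int32_t>(onnx_strides[0]);  // index 0 -> height
      attr.strideW = static_cast<int32_t>(onnx_strides[1]);  // index 1 -> width

      printf("strideH=%d strideW=%d\n", attr.strideH, attr.strideW);  // strideH=2 strideW=1
      return 0;
    }
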