From 49517fa756c4c70a8a3fd1c9005cdc8842124ddc Mon Sep 17 00:00:00 2001 From: zengxianglong Date: Tue, 22 Dec 2020 09:59:02 +0800 Subject: [PATCH] remove redundant preprocess in arithmetic operator and fix some fp16 bugs --- .../kernel/arm/fp16/arithmetic_fp16.cc | 52 ++++++------------- .../runtime/kernel/arm/fp16/arithmetic_fp16.h | 2 +- .../kernel/arm/fp32/arithmetic_fp32.cc | 36 +------------ .../runtime/kernel/arm/fp32/arithmetic_fp32.h | 1 - mindspore/lite/test/models_onnx.cfg | 2 +- mindspore/lite/test/models_onnx_fp16.cfg | 2 +- mindspore/lite/tools/benchmark/benchmark.cc | 10 ++-- 7 files changed, 26 insertions(+), 79 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index de05d3e105..f1498ff47b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -21,8 +21,8 @@ #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" -#include "src/ops/populate/arithmetic_populate.h" #include "include/errorcode.h" +#include "src/ops/arithmetic.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -100,44 +100,26 @@ int ArithmeticFP16CPUKernel::Init() { return ReSize(); } -int ArithmeticFP16CPUKernel::PreProcess() { - if (!InferShapeDone()) { - (const_cast(primitive_))->set_infer_flag(true); - auto ret = (const_cast(primitive_))->InferShape(in_tensors_, out_tensors_); - if (ret != 0) { - (const_cast(primitive_))->set_infer_flag(false); - MS_LOG(ERROR) << "InferShape fail!"; - return ret; - } - if (op_parameter_ != nullptr) { - free(op_parameter_); - op_parameter_ = nullptr; - } - op_parameter_ = PopulateArithmetic(primitive_); - if (op_parameter_ == nullptr) { - MS_LOG(ERROR) << "Malloc parameter failed"; - return RET_ERROR; - } - param_ = reinterpret_cast(op_parameter_); - ret = ReSize(); - if (ret != 0) { - MS_LOG(ERROR) << "ReSize fail!ret: " << ret; - return ret; - } - } +void ArithmeticFP16CPUKernel::InitParam() { + auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_; + param_->broadcasting_ = arithmetic_lite_primitive->Broadcasting(); + param_->ndim_ = arithmetic_lite_primitive->NDims(); - auto outputs = this->out_tensors(); - for (auto *output : outputs) { - MS_ASSERT(output != nullptr); - output->MallocData(); - } - return RET_OK; + param_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); + param_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); + param_->out_elements_num_ = out_tensors_[0]->ElementsNum(); + memcpy(param_->in_shape0_, reinterpret_cast(primitive_)->InShape0().data(), + reinterpret_cast(primitive_)->InShape0().size() * sizeof(int)); + memcpy(param_->in_shape1_, reinterpret_cast(primitive_)->InShape1().data(), + reinterpret_cast(primitive_)->InShape1().size() * sizeof(int)); + memcpy(param_->out_shape_, reinterpret_cast(primitive_)->OutputShape().data(), + reinterpret_cast(primitive_)->OutputShape().size() * sizeof(int)); + + return; } int ArithmeticFP16CPUKernel::ReSize() { - param_->in_elements_num0_ = in_tensors_.at(0)->ElementsNum(); - param_->in_elements_num1_ = in_tensors_.at(1)->ElementsNum(); - param_->out_elements_num_ = out_tensors_.at(0)->ElementsNum(); + InitParam(); if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) { param_->broadcasting_ = false; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h index 4eed7e1c14..5e95858747 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h @@ -44,7 +44,6 @@ class ArithmeticFP16CPUKernel : public LiteKernel { ~ArithmeticFP16CPUKernel() = default; int Init() override; - int PreProcess() override; int ReSize() override; int Run() override; int DoArithmetic(int task_id); @@ -52,6 +51,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel { int out_thread_stride); private: + void InitParam(); void FreeTmpBuffer(); int outside_; int break_pos_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index cbab0cfb22..d789cfafd8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -20,7 +20,7 @@ #include "src/kernel_registry.h" #include "src/runtime/kernel/arm/int8/add_int8.h" #include "src/runtime/runtime_api.h" -#include "src/ops/populate/arithmetic_populate.h" +#include "src/ops/arithmetic.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -94,40 +94,6 @@ int ArithmeticCPUKernel::InitBroadCastCase() { return RET_OK; } -int ArithmeticCPUKernel::PreProcess() { - if (!InferShapeDone()) { - (const_cast(primitive_))->set_infer_flag(true); - auto ret = (const_cast(primitive_))->InferShape(in_tensors_, out_tensors_); - if (ret != 0) { - (const_cast(primitive_))->set_infer_flag(false); - MS_LOG(ERROR) << "InferShape fail!"; - return ret; - } - if (op_parameter_ != nullptr) { - free(op_parameter_); - op_parameter_ = nullptr; - } - op_parameter_ = PopulateArithmetic(primitive_); - if (op_parameter_ == nullptr) { - MS_LOG(ERROR) << "Malloc parameter failed"; - return RET_ERROR; - } - arithmeticParameter_ = reinterpret_cast(op_parameter_); - ret = ReSize(); - if (ret != 0) { - MS_LOG(ERROR) << "ReSize fail!ret: " << ret; - return ret; - } - } - - auto outputs = this->out_tensors(); - for (auto *output : outputs) { - MS_ASSERT(output != nullptr); - output->MallocData(); - } - return RET_OK; -} - void ArithmeticCPUKernel::InitRunFunction() { switch (op_parameter_->type_) { case PrimitiveType_Mul: diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index 22f474ba1d..78bedefdf3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -62,7 +62,6 @@ class ArithmeticCPUKernel : public LiteKernel { ~ArithmeticCPUKernel() override; int Init() override; - int PreProcess() override; int ReSize() override; int Run() override; virtual int DoArithmetic(int task_id); diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index 9771976a77..039bcf498c 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -7,7 +7,7 @@ efficientnet-lite4-11.onnx mobilenetv2-7.onnx shufflenet-v2-10.onnx squeezenet1.1-7.onnx -#densenet-9.onnx +densenet-9.onnx ml_table_detection_fp32.onnx ml_table_segment.onnx googlenet-9.onnx diff --git a/mindspore/lite/test/models_onnx_fp16.cfg b/mindspore/lite/test/models_onnx_fp16.cfg index d05f5548ea..c71a920f40 100644 --- a/mindspore/lite/test/models_onnx_fp16.cfg +++ b/mindspore/lite/test/models_onnx_fp16.cfg @@ -7,7 +7,7 @@ efficientnet-lite4-11.onnx 2 mobilenetv2-7.onnx 8 shufflenet-v2-10.onnx 5 squeezenet1.1-7.onnx 1 -#densenet-9.onnx 6 +densenet-9.onnx 6 ml_table_detection_fp32.onnx 2 ml_table_segment.onnx 2 googlenet-9.onnx 3 diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc index 2eaf256ac8..018760ea18 100644 --- a/mindspore/lite/tools/benchmark/benchmark.cc +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -38,20 +38,20 @@ int Benchmark::GenerateRandomData(size_t size, void *data, TypeId data_type) { switch (data_type) { case kNumberTypeFloat32: case kNumberTypeFloat: - FillInputData(size, data, std::uniform_real_distribution(-0.5f, 0.5f)); + FillInputData(size, data, std::uniform_real_distribution(0.1f, 1.0f)); break; case kNumberTypeFloat64: - FillInputData(size, data, std::uniform_real_distribution(-0.5, 0.5)); + FillInputData(size, data, std::uniform_real_distribution(0.1, 1.0)); break; case kNumberTypeInt64: - FillInputData(size, data, std::uniform_int_distribution(0, 99)); + FillInputData(size, data, std::uniform_int_distribution(0, 1)); break; case kNumberTypeInt: case kNumberTypeInt32: - FillInputData(size, data, std::uniform_int_distribution(0, 99)); + FillInputData(size, data, std::uniform_int_distribution(0, 1)); break; case kNumberTypeInt16: - FillInputData(size, data, std::uniform_int_distribution(0, 99)); + FillInputData(size, data, std::uniform_int_distribution(0, 1)); break; case kNumberTypeInt8: FillInputData(size, data, std::uniform_int_distribution(-127, 127));