From 58c808d704fee89943893a9ba93b9aa5731b1c6b Mon Sep 17 00:00:00 2001 From: wangzhe Date: Sun, 25 Oct 2020 16:18:17 +0800 Subject: [PATCH] dpp op refactor GetInputData & process int8 anchor data --- .../arm/base/detection_post_process_base.cc | 26 +++++-------------- .../arm/base/detection_post_process_base.h | 2 +- .../kernel/arm/fp32/detection_post_process.cc | 11 ++++++++ .../kernel/arm/fp32/detection_post_process.h | 3 +++ .../lite/test/models_tflite_awaretraining.cfg | 6 ++--- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.cc index b470ae6e3b..a7783c975f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.cc @@ -30,21 +30,18 @@ int DetectionPostProcessBaseCPUKernel::Init() { auto anchor_tensor = in_tensors_.at(2); DetectionPostProcessParameter *parameter = reinterpret_cast(op_parameter_); parameter->anchors_ = nullptr; - if (anchor_tensor->data_type() == kNumberTypeUInt8) { - const auto quant_params = anchor_tensor->GetQuantParams(); - const double scale = quant_params.at(0).scale; - const int32_t zp = quant_params.at(0).zeroPoint; - auto anchor_uint8 = reinterpret_cast(anchor_tensor->MutableData()); + if (anchor_tensor->data_type() == kNumberTypeInt8) { + auto quant_param = anchor_tensor->GetQuantParams().front(); + auto anchor_int8 = reinterpret_cast(anchor_tensor->MutableData()); auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()]; if (anchor_fp32 == nullptr) { MS_LOG(ERROR) << "Malloc anchor failed"; return RET_ERROR; } - for (int i = 0; i < anchor_tensor->ElementsNum(); ++i) { - *(anchor_fp32 + i) = static_cast((static_cast(anchor_uint8[i]) - zp) * scale); - } + DoDequantizeInt8ToFp32(anchor_int8, anchor_fp32, quant_param.scale, quant_param.zeroPoint, + 
anchor_tensor->ElementsNum()); parameter->anchors_ = anchor_fp32; - } else if (anchor_tensor->data_type() == kNumberTypeFloat32) { + } else if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) { parameter->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()]; if (parameter->anchors_ == nullptr) { MS_LOG(ERROR) << "Malloc anchor failed"; @@ -65,17 +62,6 @@ DetectionPostProcessBaseCPUKernel::~DetectionPostProcessBaseCPUKernel() { int DetectionPostProcessBaseCPUKernel::ReSize() { return RET_OK; } -int DetectionPostProcessBaseCPUKernel::GetInputData() { - if ((in_tensors_.at(0)->data_type() != kNumberTypeFloat32 && in_tensors_.at(0)->data_type() != kNumberTypeFloat) || - (in_tensors_.at(1)->data_type() != kNumberTypeFloat32 && in_tensors_.at(1)->data_type() != kNumberTypeFloat)) { - MS_LOG(ERROR) << "Input data type error"; - return RET_ERROR; - } - input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); - input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); - return RET_OK; -} - int DetectionPostProcessBaseCPUKernel::Run() { MS_ASSERT(context_->allocator != nullptr); int status = GetInputData(); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.h b/mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.h index c9a09fc356..c48017a8d7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.h @@ -41,7 +41,7 @@ class DetectionPostProcessBaseCPUKernel : public LiteKernel { float *input_boxes; float *input_scores; - int GetInputData(); + virtual int GetInputData() = 0; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_BASE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.cc 
b/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.cc index f9118a7758..2c550849df 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.cc @@ -27,6 +27,17 @@ using mindspore::schema::PrimitiveType_DetectionPostProcess; namespace mindspore::kernel { +int DetectionPostProcessCPUKernel::GetInputData() { + if ((in_tensors_.at(0)->data_type() != kNumberTypeFloat32 && in_tensors_.at(0)->data_type() != kNumberTypeFloat) || + (in_tensors_.at(1)->data_type() != kNumberTypeFloat32 && in_tensors_.at(1)->data_type() != kNumberTypeFloat)) { + MS_LOG(ERROR) << "Input data type error"; + return RET_ERROR; + } + input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); + input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); + return RET_OK; +} + kernel::LiteKernel *CpuDetectionPostProcessFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::InnerContext *ctx, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h b/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h index fe48387353..4afe75500c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h @@ -33,6 +33,9 @@ class DetectionPostProcessCPUKernel : public DetectionPostProcessBaseCPUKernel { const mindspore::lite::PrimitiveC *primitive) : DetectionPostProcessBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} ~DetectionPostProcessCPUKernel() = default; + + private: + int GetInputData(); }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_ diff --git a/mindspore/lite/test/models_tflite_awaretraining.cfg b/mindspore/lite/test/models_tflite_awaretraining.cfg index dac3fc8b41..b4645eeff7 100644 --- 
a/mindspore/lite/test/models_tflite_awaretraining.cfg +++ b/mindspore/lite/test/models_tflite_awaretraining.cfg @@ -33,6 +33,6 @@ lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite vision_classifier_fungi_mobile_V1_1_default_1.tflite -#detect.tflite -#ssd_mobilenet_v1_1_default_1.tflite -#object_detection_mobile_object_localizer_v1_1_default_1.tflite +detect.tflite +ssd_mobilenet_v1_1_default_1.tflite +object_detection_mobile_object_localizer_v1_1_default_1.tflite