Browse Source

dpp op refactor GetInputData & process int8 anchor data

tags/v1.1.0
wangzhe 5 years ago
parent
commit
58c808d704
5 changed files with 24 additions and 24 deletions
  1. +6
    -20
      mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.cc
  2. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.h
  3. +11
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.cc
  4. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h
  5. +3
    -3
      mindspore/lite/test/models_tflite_awaretraining.cfg

+ 6
- 20
mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.cc View File

@@ -30,21 +30,18 @@ int DetectionPostProcessBaseCPUKernel::Init() {
auto anchor_tensor = in_tensors_.at(2); auto anchor_tensor = in_tensors_.at(2);
DetectionPostProcessParameter *parameter = reinterpret_cast<DetectionPostProcessParameter *>(op_parameter_); DetectionPostProcessParameter *parameter = reinterpret_cast<DetectionPostProcessParameter *>(op_parameter_);
parameter->anchors_ = nullptr; parameter->anchors_ = nullptr;
if (anchor_tensor->data_type() == kNumberTypeUInt8) {
const auto quant_params = anchor_tensor->GetQuantParams();
const double scale = quant_params.at(0).scale;
const int32_t zp = quant_params.at(0).zeroPoint;
auto anchor_uint8 = reinterpret_cast<uint8_t *>(anchor_tensor->MutableData());
if (anchor_tensor->data_type() == kNumberTypeInt8) {
auto quant_param = anchor_tensor->GetQuantParams().front();
auto anchor_int8 = reinterpret_cast<int8_t *>(anchor_tensor->MutableData());
auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()]; auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (anchor_fp32 == nullptr) { if (anchor_fp32 == nullptr) {
MS_LOG(ERROR) << "Malloc anchor failed"; MS_LOG(ERROR) << "Malloc anchor failed";
return RET_ERROR; return RET_ERROR;
} }
for (int i = 0; i < anchor_tensor->ElementsNum(); ++i) {
*(anchor_fp32 + i) = static_cast<float>((static_cast<int>(anchor_uint8[i]) - zp) * scale);
}
DoDequantizeInt8ToFp32(anchor_int8, anchor_fp32, quant_param.scale, quant_param.zeroPoint,
anchor_tensor->ElementsNum());
parameter->anchors_ = anchor_fp32; parameter->anchors_ = anchor_fp32;
} else if (anchor_tensor->data_type() == kNumberTypeFloat32) {
} else if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) {
parameter->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()]; parameter->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (parameter->anchors_ == nullptr) { if (parameter->anchors_ == nullptr) {
MS_LOG(ERROR) << "Malloc anchor failed"; MS_LOG(ERROR) << "Malloc anchor failed";
@@ -65,17 +62,6 @@ DetectionPostProcessBaseCPUKernel::~DetectionPostProcessBaseCPUKernel() {


int DetectionPostProcessBaseCPUKernel::ReSize() { return RET_OK; } int DetectionPostProcessBaseCPUKernel::ReSize() { return RET_OK; }


int DetectionPostProcessBaseCPUKernel::GetInputData() {
if ((in_tensors_.at(0)->data_type() != kNumberTypeFloat32 && in_tensors_.at(0)->data_type() != kNumberTypeFloat) ||
(in_tensors_.at(1)->data_type() != kNumberTypeFloat32 && in_tensors_.at(1)->data_type() != kNumberTypeFloat)) {
MS_LOG(ERROR) << "Input data type error";
return RET_ERROR;
}
input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
return RET_OK;
}

int DetectionPostProcessBaseCPUKernel::Run() { int DetectionPostProcessBaseCPUKernel::Run() {
MS_ASSERT(context_->allocator != nullptr); MS_ASSERT(context_->allocator != nullptr);
int status = GetInputData(); int status = GetInputData();


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.h View File

@@ -41,7 +41,7 @@ class DetectionPostProcessBaseCPUKernel : public LiteKernel {
float *input_boxes; float *input_boxes;
float *input_scores; float *input_scores;


int GetInputData();
virtual int GetInputData() = 0;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_BASE_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_BASE_H_

+ 11
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.cc View File

@@ -27,6 +27,17 @@ using mindspore::schema::PrimitiveType_DetectionPostProcess;


namespace mindspore::kernel { namespace mindspore::kernel {


int DetectionPostProcessCPUKernel::GetInputData() {
if ((in_tensors_.at(0)->data_type() != kNumberTypeFloat32 && in_tensors_.at(0)->data_type() != kNumberTypeFloat) ||
(in_tensors_.at(1)->data_type() != kNumberTypeFloat32 && in_tensors_.at(1)->data_type() != kNumberTypeFloat)) {
MS_LOG(ERROR) << "Input data type error";
return RET_ERROR;
}
input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
return RET_OK;
}

kernel::LiteKernel *CpuDetectionPostProcessFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, kernel::LiteKernel *CpuDetectionPostProcessFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter, const lite::InnerContext *ctx, OpParameter *opParameter, const lite::InnerContext *ctx,


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h View File

@@ -33,6 +33,9 @@ class DetectionPostProcessCPUKernel : public DetectionPostProcessBaseCPUKernel {
const mindspore::lite::PrimitiveC *primitive) const mindspore::lite::PrimitiveC *primitive)
: DetectionPostProcessBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} : DetectionPostProcessBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~DetectionPostProcessCPUKernel() = default; ~DetectionPostProcessCPUKernel() = default;

private:
int GetInputData();
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_

+ 3
- 3
mindspore/lite/test/models_tflite_awaretraining.cfg View File

@@ -33,6 +33,6 @@ lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V
lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite
lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite
vision_classifier_fungi_mobile_V1_1_default_1.tflite vision_classifier_fungi_mobile_V1_1_default_1.tflite
#detect.tflite
#ssd_mobilenet_v1_1_default_1.tflite
#object_detection_mobile_object_localizer_v1_1_default_1.tflite
detect.tflite
ssd_mobilenet_v1_1_default_1.tflite
object_detection_mobile_object_localizer_v1_1_default_1.tflite

Loading…
Cancel
Save