Browse Source

dpp anchor process uint8 data in order to support old model

tags/v1.1.0
wangzhe 5 years ago
parent
commit
4c2d8a3181
1 changed files with 15 additions and 2 deletions
  1. +15
    -2
      mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.cc

+ 15
- 2
mindspore/lite/src/runtime/kernel/arm/base/detection_post_process_base.cc View File

@@ -49,7 +49,8 @@ int DetectionPostProcessBaseCPUKernel::Init() {
auto anchor_tensor = in_tensors_.at(2);
if (anchor_tensor->data_type() == kNumberTypeInt8) {
auto quant_param = anchor_tensor->GetQuantParams().front();
auto anchor_int8 = reinterpret_cast<int8_t *>(anchor_tensor->MutableData());
auto anchor_int8 = reinterpret_cast<int8_t *>(anchor_tensor->data_c());
MS_ASSERT(anchor_int8 != nullptr);
auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (anchor_fp32 == nullptr) {
MS_LOG(ERROR) << "Malloc anchor failed";
@@ -58,13 +59,25 @@ int DetectionPostProcessBaseCPUKernel::Init() {
DoDequantizeInt8ToFp32(anchor_int8, anchor_fp32, quant_param.scale, quant_param.zeroPoint,
anchor_tensor->ElementsNum());
params_->anchors_ = anchor_fp32;
} else if (anchor_tensor->data_type() == kNumberTypeUInt8) {
auto quant_param = anchor_tensor->GetQuantParams().front();
auto anchor_uint8 = reinterpret_cast<uint8_t *>(anchor_tensor->data_c());
MS_ASSERT(anchor_uint8 != nullptr);
auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (anchor_fp32 == nullptr) {
MS_LOG(ERROR) << "Malloc anchor failed";
return RET_ERROR;
}
DoDequantizeUInt8ToFp32(anchor_uint8, anchor_fp32, quant_param.scale, quant_param.zeroPoint,
anchor_tensor->ElementsNum());
params_->anchors_ = anchor_fp32;
} else if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) {
params_->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (params_->anchors_ == nullptr) {
MS_LOG(ERROR) << "Malloc anchor failed";
return RET_ERROR;
}
memcpy(params_->anchors_, anchor_tensor->MutableData(), anchor_tensor->Size());
memcpy(params_->anchors_, anchor_tensor->data_c(), anchor_tensor->Size());
} else {
MS_LOG(ERROR) << "unsupported anchor data type " << anchor_tensor->data_type();
return RET_ERROR;


Loading…
Cancel
Save