|
|
|
@@ -49,7 +49,8 @@ int DetectionPostProcessBaseCPUKernel::Init() { |
|
|
|
auto anchor_tensor = in_tensors_.at(2); |
|
|
|
if (anchor_tensor->data_type() == kNumberTypeInt8) { |
|
|
|
auto quant_param = anchor_tensor->GetQuantParams().front(); |
|
|
|
auto anchor_int8 = reinterpret_cast<int8_t *>(anchor_tensor->MutableData()); |
|
|
|
auto anchor_int8 = reinterpret_cast<int8_t *>(anchor_tensor->data_c()); |
|
|
|
MS_ASSERT(anchor_int8 != nullptr); |
|
|
|
auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()]; |
|
|
|
if (anchor_fp32 == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Malloc anchor failed"; |
|
|
|
@@ -58,13 +59,25 @@ int DetectionPostProcessBaseCPUKernel::Init() { |
|
|
|
DoDequantizeInt8ToFp32(anchor_int8, anchor_fp32, quant_param.scale, quant_param.zeroPoint, |
|
|
|
anchor_tensor->ElementsNum()); |
|
|
|
params_->anchors_ = anchor_fp32; |
|
|
|
} else if (anchor_tensor->data_type() == kNumberTypeUInt8) { |
|
|
|
auto quant_param = anchor_tensor->GetQuantParams().front(); |
|
|
|
auto anchor_uint8 = reinterpret_cast<uint8_t *>(anchor_tensor->data_c()); |
|
|
|
MS_ASSERT(anchor_uint8 != nullptr); |
|
|
|
auto anchor_fp32 = new (std::nothrow) float[anchor_tensor->ElementsNum()]; |
|
|
|
if (anchor_fp32 == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Malloc anchor failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
DoDequantizeUInt8ToFp32(anchor_uint8, anchor_fp32, quant_param.scale, quant_param.zeroPoint, |
|
|
|
anchor_tensor->ElementsNum()); |
|
|
|
params_->anchors_ = anchor_fp32; |
|
|
|
} else if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) { |
|
|
|
params_->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()]; |
|
|
|
if (params_->anchors_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Malloc anchor failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
memcpy(params_->anchors_, anchor_tensor->MutableData(), anchor_tensor->Size()); |
|
|
|
memcpy(params_->anchors_, anchor_tensor->data_c(), anchor_tensor->Size()); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "unsupported anchor data type " << anchor_tensor->data_type(); |
|
|
|
return RET_ERROR; |
|
|
|
|