diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
index 97b8cc2597..d91a5cd383 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
@@ -46,30 +46,33 @@ int QuantDTypeCastCPUKernel::Init() {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
-    inverse_ = false;
   } else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeFloat32) {
     if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat32) {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
-    inverse_ = true;
   } else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeInt8) {
     if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeInt8) {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
-    inverse_ = false;
+  } else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeInt8) {
+    if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
   } else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeUInt8) {
     if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeUInt8) {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
-    inverse_ = true;
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
     return RET_PARAM_INVALID;
   }
+  src_dtype = param->srcT;
+  dst_dtype = param->dstT;
 
   if (!InferShapeDone()) {
     return RET_OK;
@@ -97,20 +100,25 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
   }
   auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ?
                      out_tensors_.front()->GetQuantParams().front() : in_tensors_.front()->GetQuantParams().front();
-  int ret;
-  if (uint8_ptr_ == nullptr) {
-    if (inverse_) {
-      ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+  int ret = RET_OK;
+  if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
+    ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+                                 quant_arg.zeroPoint, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) {
+    ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
                                    quant_arg.zeroPoint, num_unit_thread);
-    } else {
-      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                                     quant_arg.zeroPoint, num_unit_thread);
-    }
-  } else {
-    if (inverse_) {
-      ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
-    } else {
-      ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
+    ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
+    ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+  } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
+    auto input_quant_arg = in_tensors_.front()->GetQuantParams().front();
+    ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, input_quant_arg.scale,
+                                 input_quant_arg.zeroPoint, num_unit_thread);
+    if (ret == RET_OK) {
+      auto output_quant_arg = out_tensors_.front()->GetQuantParams().front();
+      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset,
+                                     output_quant_arg.scale, output_quant_arg.zeroPoint, num_unit_thread);
     }
   }
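
The new int8 -> int8 branch above is a requantization: each value is dequantized with the input tensor's quantization parameters and immediately quantized again with the output tensor's. For reference, a minimal standalone sketch of the math the two NNACL calls perform, assuming round-to-nearest; RequantizeInt8 and its parameter names are illustrative, not MindSpore's API:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Requantize int8 -> int8 via an fp32 intermediate, mirroring the
// DoDequantizeInt8ToFp32 + DoQuantizeToInt8FromFp32 pair above.
void RequantizeInt8(const int8_t *in, int8_t *out, int count,
                    float in_scale, int32_t in_zp,
                    float out_scale, int32_t out_zp) {
  for (int i = 0; i < count; ++i) {
    // Dequantize with the input params: x = (q - zp_in) * scale_in.
    float x = (in[i] - in_zp) * in_scale;
    // Quantize with the output params and clamp to the int8 range.
    int32_t q = static_cast<int32_t>(std::round(x / out_scale)) + out_zp;
    out[i] = static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, q)));
  }
}

When the input and output share the same scale and zero point, the fp32 round-trip is an identity and the cast degenerates to a plain copy.
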
@@ -154,14 +162,26 @@ int QuantDTypeCastCPUKernel::Run() {
              out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
     uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c());
     int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    int8_out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
+    float32_ptr_ = new float[in_tensors_[0]->ElementsNum()];
   }
 
   auto ret = ParallelLaunch(this->context_->thread_pool_, QuantDTypeCastRun, this, thread_n_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
+    if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+        out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+      delete[] float32_ptr_;
+    }
     return RET_ERROR;
   }
-
+  if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+      out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    delete[] float32_ptr_;
+  }
   return RET_OK;
 }
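
Because float32_ptr_ is allocated with new[] in the int8 -> int8 case, it must be released with delete[], and that release currently has to be repeated on both the error path and the success path above. A scoped buffer would collapse the two cleanup blocks into one; a sketch assuming C++11 smart pointers are acceptable in this codebase (RunWithScratch and element_num are illustrative names, not the kernel's real interface):

#include <memory>
#include <new>

// The unique_ptr owns the scratch buffer and delete[]s it on every
// return path, so no hand-written cleanup blocks are needed.
bool RunWithScratch(int element_num) {
  std::unique_ptr<float[]> fp32_buf(new (std::nothrow) float[element_num]);
  if (fp32_buf == nullptr) {
    return false;  // allocation failed
  }
  // ... hand fp32_buf.get() to the per-thread workers and launch ...
  return true;  // fp32_buf released here, success or failure alike
}
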
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h
index c16305e843..1560bcb63b 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h
@@ -39,10 +39,13 @@ class QuantDTypeCastCPUKernel : public LiteKernel {
   int thread_n_num_;
   int thread_n_stride_;
   int num_unit_;
-  int8_t *int8_ptr_;
+  int8_t *int8_ptr_ = nullptr;
+  int8_t *int8_out_ptr_ = nullptr;
   uint8_t *uint8_ptr_ = nullptr;
-  float *float32_ptr_;
-  bool inverse_;
+  float *float32_ptr_ = nullptr;
+
+  int32_t src_dtype;
+  int32_t dst_dtype;
 };
 }  // namespace mindspore::kernel
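
One follow-up on the header change: the pointer members now get in-class nullptr defaults, but src_dtype and dst_dtype are left uninitialized and only assigned in Init(). An illustrative mirror of the member list (not the real header) showing how they could get the same treatment; that 0 corresponds to kTypeUnknown in MindSpore's TypeId enum is an assumption worth checking against type_id.h:

#include <cstdint>

class QuantCastMembersSketch {
 public:
  int8_t *int8_ptr_ = nullptr;
  int8_t *int8_out_ptr_ = nullptr;
  uint8_t *uint8_ptr_ = nullptr;
  float *float32_ptr_ = nullptr;
  int32_t src_dtype = 0;  // assumed: 0 == kTypeUnknown in TypeId
  int32_t dst_dtype = 0;
};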