|
|
|
@@ -46,30 +46,33 @@ int QuantDTypeCastCPUKernel::Init() { |
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
inverse_ = false; |
|
|
|
} else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeFloat32) { |
|
|
|
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat32) { |
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
inverse_ = true; |
|
|
|
} else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeInt8) { |
|
|
|
if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeInt8) { |
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
inverse_ = false; |
|
|
|
} else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeInt8) { |
|
|
|
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeInt8) { |
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeUInt8) { |
|
|
|
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeUInt8) { |
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
inverse_ = true; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "param data type not supported:" |
|
|
|
<< " src: " << param->srcT << " dst: " << param->dstT; |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
src_dtype = param->srcT; |
|
|
|
dst_dtype = param->dstT; |
|
|
|
|
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
@@ -97,20 +100,25 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) { |
|
|
|
} |
|
|
|
auto quant_arg = !out_tensors_.front()->GetQuantParams().empty() ? out_tensors_.front()->GetQuantParams().front() |
|
|
|
: in_tensors_.front()->GetQuantParams().front(); |
|
|
|
int ret; |
|
|
|
if (uint8_ptr_ == nullptr) { |
|
|
|
if (inverse_) { |
|
|
|
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, |
|
|
|
int ret = RET_OK; |
|
|
|
if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeFloat32) { |
|
|
|
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale, |
|
|
|
quant_arg.zeroPoint, num_unit_thread); |
|
|
|
} else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) { |
|
|
|
ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, |
|
|
|
quant_arg.zeroPoint, num_unit_thread); |
|
|
|
} else { |
|
|
|
ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, |
|
|
|
quant_arg.zeroPoint, num_unit_thread); |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (inverse_) { |
|
|
|
ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread); |
|
|
|
} else { |
|
|
|
ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread); |
|
|
|
} else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) { |
|
|
|
ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread); |
|
|
|
} else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) { |
|
|
|
ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread); |
|
|
|
} else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeInt8) { |
|
|
|
auto input_quant_arg = in_tensors_.front()->GetQuantParams().front(); |
|
|
|
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, num_unit_thread, |
|
|
|
input_quant_arg.scale, input_quant_arg.zeroPoint); |
|
|
|
if (ret) { |
|
|
|
auto output_quant_arg = out_tensors_.front()->GetQuantParams().front(); |
|
|
|
ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, |
|
|
|
output_quant_arg.scale, output_quant_arg.zeroPoint, num_unit_thread); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -154,14 +162,26 @@ int QuantDTypeCastCPUKernel::Run() { |
|
|
|
out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) { |
|
|
|
uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c()); |
|
|
|
int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c()); |
|
|
|
} else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 && |
|
|
|
out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) { |
|
|
|
int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c()); |
|
|
|
int8_out_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c()); |
|
|
|
float32_ptr_ = new float[in_tensors_[0]->ElementsNum()]; |
|
|
|
} |
|
|
|
|
|
|
|
auto ret = ParallelLaunch(this->context_->thread_pool_, QuantDTypeCastRun, this, thread_n_num_); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; |
|
|
|
if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 && |
|
|
|
out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) { |
|
|
|
delete (float32_ptr_); |
|
|
|
} |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 && |
|
|
|
out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) { |
|
|
|
delete (float32_ptr_); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
|