@@ -47,13 +47,29 @@ int QuantDTypeCastFp16CPUKernel::Init() {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
-    inverse_ = false;
+    int_to_float_ = false;
+    is_uint8_ = false;
   } else if (param->srcT == kNumberTypeInt8) {
     if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat16) {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
-    inverse_ = true;
+    int_to_float_ = true;
+    is_uint8_ = false;
+  } else if (param->dstT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeFloat16 || out_tensor->data_type() != kNumberTypeUInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    int_to_float_ = false;
+    is_uint8_ = true;
+  } else if (param->srcT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeFloat16) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    int_to_float_ = true;
+    is_uint8_ = true;
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
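Note: the branches above only record the direction of the standard affine mapping, real = scale * (q - zeroPoint), and whether the quantized side is int8 or uint8. A minimal scalar sketch of that mapping for the new uint8 path (hypothetical helper names; the production DoDequantizeUInt8ToFp16 / DoQuantizeFp16ToUInt8 are the vectorized nnacl kernels, and plain float stands in for float16_t so the sketch builds anywhere):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical scalar reference for the affine (de)quantization that
    // int_to_float_ / is_uint8_ select between.
    // Dequantize: real = scale * (q - zeroPoint).
    void DequantizeUInt8ToFloat(const uint8_t *in, float *out, float scale, int32_t zp, int n) {
      for (int i = 0; i < n; ++i) {
        out[i] = scale * (static_cast<int32_t>(in[i]) - zp);
      }
    }

    // Quantize: q = round(real / scale) + zeroPoint, clamped to the uint8 range.
    void QuantizeFloatToUInt8(const float *in, uint8_t *out, float scale, int32_t zp, int n) {
      for (int i = 0; i < n; ++i) {
        int32_t q = static_cast<int32_t>(std::round(in[i] / scale)) + zp;
        out[i] = static_cast<uint8_t>(std::min(255, std::max(0, q)));
      }
    }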
@@ -87,14 +103,26 @@ int QuantDTypeCastFp16CPUKernel::QuantDTypeCast(int task_id) {
   auto quant_arg = !out_tensors_.front()->quant_params().empty() ? out_tensors_.front()->quant_params().front()
                                                                  : in_tensors_.front()->quant_params().front();
   int ret;
-  MS_ASSERT(int8_ptr_);
   MS_ASSERT(float16_ptr_);
-  if (inverse_) {
-    ret = DoDequantizeInt8ToFp16(int8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
+  if (!is_uint8_) {
+    MS_ASSERT(int8_ptr_);
+    if (int_to_float_) {
+      ret = DoDequantizeInt8ToFp16(int8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
                                  quant_arg.zeroPoint, num_unit_thread);
+    } else {
+      ret = DoQuantizeFp16ToInt8(float16_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                                 quant_arg.zeroPoint, num_unit_thread);
+    }
   } else {
-    ret = DoQuantizeToInt8FromFp16(float16_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                                   quant_arg.zeroPoint, num_unit_thread);
+    // uint8
+    MS_ASSERT(uint8_ptr_);
+    if (int_to_float_) {
+      ret = DoDequantizeUInt8ToFp16(uint8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
+                                    quant_arg.zeroPoint, num_unit_thread);
+    } else {
+      ret = DoQuantizeFp16ToUInt8(float16_ptr_ + thread_offset, uint8_ptr_ + thread_offset, quant_arg.scale,
+                                  quant_arg.zeroPoint, num_unit_thread);
+    }
   }

   if (ret != RET_OK) {
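Note: each task_id processes one contiguous slice of the tensor; thread_offset and num_unit_thread are computed from the element count and thread count before this dispatch runs. A sketch of that partitioning (illustrative only; the names and exact rounding are assumptions, not taken from this diff):

    #include <algorithm>

    // Hypothetical slice computation for a given task_id: stride is the
    // per-thread unit count rounded up, and the last thread takes the remainder.
    int SliceForTask(int num_unit, int thread_count, int task_id, int *thread_offset) {
      int stride = (num_unit + thread_count - 1) / thread_count;  // ceil(num_unit / thread_count)
      *thread_offset = stride * task_id;
      return std::max(0, std::min(stride, num_unit - *thread_offset));  // num_unit_thread
    }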
@@ -123,6 +151,14 @@ int QuantDTypeCastFp16CPUKernel::Run() {
       out_tensors_.at(0)->data_type() == TypeId::kNumberTypeInt8) {
     float16_ptr_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
     int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
+  } else if (in_tensors_.at(0)->data_type() == TypeId::kNumberTypeUInt8 &&
+             out_tensors_.at(0)->data_type() == TypeId::kNumberTypeFloat16) {
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_.at(0)->data_c());
+    float16_ptr_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
+  } else if (in_tensors_.at(0)->data_type() == TypeId::kNumberTypeFloat16 &&
+             out_tensors_.at(0)->data_type() == TypeId::kNumberTypeUInt8) {
+    float16_ptr_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->data_c());
   } else {
     MS_LOG(ERROR) << "QuantDTypeCastFp16 not support input or output type";
     return RET_ERROR;
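Note: with the pointers wired up this way, the new uint8 path should round-trip within one quantization step. A standalone sanity check (hypothetical, using the scalar affine mapping sketched above with arbitrarily chosen scale 0.05 and zero point 128; float stands in for float16_t):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const float scale = 0.05f;
      const int32_t zero_point = 128;
      const float real = 1.23f;
      // Quantize: q = round(real / scale) + zeroPoint, clamped to [0, 255].
      int32_t q = static_cast<int32_t>(std::round(real / scale)) + zero_point;
      q = q < 0 ? 0 : (q > 255 ? 255 : q);
      const uint8_t quantized = static_cast<uint8_t>(q);
      // Dequantize: restored = scale * (q - zeroPoint).
      const float restored = scale * (static_cast<int32_t>(quantized) - zero_point);
      std::printf("real=%.4f quantized=%u restored=%.4f\n", real, static_cast<unsigned>(quantized), restored);
      return std::fabs(restored - real) <= scale ? 0 : 1;  // error within one quantization step
    }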