diff --git a/mindspore/lite/nnacl/fp16/quant_dtype_cast_fp16.c b/mindspore/lite/nnacl/fp16/quant_dtype_cast_fp16.c
index 9abc836424..d1c18e5bdf 100644
--- a/mindspore/lite/nnacl/fp16/quant_dtype_cast_fp16.c
+++ b/mindspore/lite/nnacl/fp16/quant_dtype_cast_fp16.c
@@ -29,12 +29,16 @@ int DoDequantizeInt8ToFp16(int8_t *quant_values, float16_t *real_values, float s
   return NNACL_OK;
 }
 
-int DoQuantizeToInt8FromFp16(float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
+int DoQuantizeFp16ToInt8(float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
 
   for (int i = 0; i < size; ++i) {
+    if (isinf(real_values[i])) {
+      quant_values[i] = 127;
+      continue;
+    }
     float temp = round((float)real_values[i] / scale + zp);
     if (temp > 127) {
       quant_values[i] = 127;
@@ -46,3 +50,37 @@ int DoQuantizeToInt8FromFp16(float16_t *real_values, int8_t *quant_values, float
   }
   return NNACL_OK;
 }
+
+int DoDequantizeUInt8ToFp16(uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) {
+  uint8_t zp_ = (uint8_t)zp;
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    real_values[i] = (quant_values[i] - zp_) * scale;
+  }
+  return NNACL_OK;
+}
+
+int DoQuantizeFp16ToUInt8(float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    if (isinf(real_values[i])) {
+      quant_values[i] = 255;
+      continue;
+    }
+    float temp = round((float)real_values[i] / scale + zp);
+    if (temp > 255.0f) {
+      quant_values[i] = 255;
+    } else if (temp < 0.0f) {
+      quant_values[i] = 0;
+    } else {
+      quant_values[i] = (uint8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore/lite/nnacl/fp16/quant_dtype_cast_fp16.h b/mindspore/lite/nnacl/fp16/quant_dtype_cast_fp16.h
index 6442355ba6..9019bf565f 100644
--- a/mindspore/lite/nnacl/fp16/quant_dtype_cast_fp16.h
+++ b/mindspore/lite/nnacl/fp16/quant_dtype_cast_fp16.h
@@ -27,7 +27,10 @@ extern "C" {
 #endif
 
 int DoDequantizeInt8ToFp16(int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size);
-int DoQuantizeToInt8FromFp16(float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoQuantizeFp16ToInt8(float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+
+int DoDequantizeUInt8ToFp16(uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size);
+int DoQuantizeFp16ToUInt8(float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/quant_dtype_cast_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/quant_dtype_cast_fp16.cc
index 8373ddee83..39730665e2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/quant_dtype_cast_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/quant_dtype_cast_fp16.cc
@@ -47,13 +47,29 @@ int QuantDTypeCastFp16CPUKernel::Init() {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
-    inverse_ = false;
+    int_to_float_ = false;
+    is_uint8_ = false;
   } else if (param->srcT == kNumberTypeInt8) {
     if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat16) {
       MS_LOG(ERROR) << "param data type and tensor data type do not match.";
       return RET_ERROR;
     }
-    inverse_ = true;
+    int_to_float_ = true;
+    is_uint8_ = false;
+  } else if (param->dstT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeFloat16 || out_tensor->data_type() != kNumberTypeUInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    int_to_float_ = false;
+    is_uint8_ = true;
+  } else if (param->srcT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeFloat16) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    int_to_float_ = true;
+    is_uint8_ = true;
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
@@ -87,14 +103,26 @@ int QuantDTypeCastFp16CPUKernel::QuantDTypeCast(int task_id) {
   auto quant_arg = !out_tensors_.front()->quant_params().empty() ? out_tensors_.front()->quant_params().front()
                                                                  : in_tensors_.front()->quant_params().front();
   int ret;
-  MS_ASSERT(int8_ptr_);
   MS_ASSERT(float16_ptr_);
-  if (inverse_) {
-    ret = DoDequantizeInt8ToFp16(int8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
+  if (!is_uint8_) {
+    MS_ASSERT(int8_ptr_);
+    if (int_to_float_) {
+      ret = DoDequantizeInt8ToFp16(int8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
+                                   quant_arg.zeroPoint, num_unit_thread);
+    } else {
+      ret = DoQuantizeFp16ToInt8(float16_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
                                  quant_arg.zeroPoint, num_unit_thread);
+    }
   } else {
-    ret = DoQuantizeToInt8FromFp16(float16_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                                   quant_arg.zeroPoint, num_unit_thread);
+    // uint8
+    MS_ASSERT(uint8_ptr_);
+    if (int_to_float_) {
+      ret = DoDequantizeUInt8ToFp16(uint8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
+                                    quant_arg.zeroPoint, num_unit_thread);
+    } else {
+      ret = DoQuantizeFp16ToUInt8(float16_ptr_ + thread_offset, uint8_ptr_ + thread_offset, quant_arg.scale,
+                                  quant_arg.zeroPoint, num_unit_thread);
+    }
   }
 
   if (ret != RET_OK) {
@@ -123,6 +151,14 @@ int QuantDTypeCastFp16CPUKernel::Run() {
              out_tensors_.at(0)->data_type() == TypeId::kNumberTypeInt8) {
     float16_ptr_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
     int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
+  } else if (in_tensors_.at(0)->data_type() == TypeId::kNumberTypeUInt8 &&
+             out_tensors_.at(0)->data_type() == TypeId::kNumberTypeFloat16) {
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_.at(0)->data_c());
+    float16_ptr_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
+  } else if (in_tensors_.at(0)->data_type() == TypeId::kNumberTypeFloat16 &&
+             out_tensors_.at(0)->data_type() == TypeId::kNumberTypeUInt8) {
+    float16_ptr_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->data_c());
   } else {
     MS_LOG(ERROR) << "QuantDTypeCastFp16 not support input or output type";
     return RET_ERROR;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/quant_dtype_cast_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/quant_dtype_cast_fp16.h
index 0c7dc353b5..bd54faa0a4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/quant_dtype_cast_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/quant_dtype_cast_fp16.h
@@ -41,8 +41,10 @@ class QuantDTypeCastFp16CPUKernel : public LiteKernel {
   int thread_n_stride_;
   int num_unit_;
   int8_t *int8_ptr_;
+  uint8_t *uint8_ptr_;
   float16_t *float16_ptr_;
-  bool inverse_;
+  bool int_to_float_;
+  bool is_uint8_;
 };
 }  // namespace mindspore::kernel
 
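
Note: the new uint8 kernels implement the standard affine quantization mapping, real = (quant - zero_point) * scale, with round-to-nearest and saturation on the quantize side (the isinf() checks clamp infinities to the maximum code). The following self-contained sketch reproduces that logic with plain float standing in for float16_t, so it compiles on any host without ARM fp16 support; the names QuantizeToUInt8 and DequantizeFromUInt8 are illustrative only, not the NNACL API (the real entry points are DoQuantizeFp16ToUInt8 and DoDequantizeUInt8ToFp16 above).

/*
 * Minimal standalone sketch of the uint8 affine quantization scheme used by
 * the patch. float stands in for float16_t; function names are hypothetical.
 */
#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* q = clamp(round(real / scale + zp), 0, 255); infinities saturate to 255,
 * mirroring the isinf() check added in the patch. */
static void QuantizeToUInt8(const float *real, uint8_t *quant, float scale, int32_t zp, int size) {
  for (int i = 0; i < size; ++i) {
    if (isinf(real[i])) {
      quant[i] = 255;
      continue;
    }
    float t = roundf(real[i] / scale + (float)zp);
    quant[i] = (t > 255.0f) ? 255 : ((t < 0.0f) ? 0 : (uint8_t)t);
  }
}

/* real = (q - zp) * scale */
static void DequantizeFromUInt8(const uint8_t *quant, float *real, float scale, int32_t zp, int size) {
  for (int i = 0; i < size; ++i) {
    real[i] = (float)(quant[i] - zp) * scale;
  }
}

int main(void) {
  /* scale = 1/255 with zp = 128 covers reals in roughly [-0.50, 0.50];
   * 1.5f lies outside that range and saturates to code 255. */
  const float in[4] = {-0.5f, 0.0f, 0.25f, 1.5f};
  uint8_t q[4];
  float out[4];
  QuantizeToUInt8(in, q, 1.0f / 255.0f, 128, 4);
  DequantizeFromUInt8(q, out, 1.0f / 255.0f, 128, 4);
  for (int i = 0; i < 4; ++i) {
    printf("%.3f -> %3u -> %.3f\n", in[i], (unsigned)q[i], out[i]);
  }
  return 0;
}

With these parameters the 1.5f input quantizes to 255 and dequantizes to about 0.498, illustrating the saturation behavior that the patch also adds to the int8 path (codes clamped to [-128, 127] there).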