From a39cae0df6014bcefd1abc2fbca0f29f2dd86581 Mon Sep 17 00:00:00 2001 From: zengxianglong Date: Thu, 10 Dec 2020 21:05:18 +0800 Subject: [PATCH] fix some bugs --- mindspore/lite/nnacl/errorcode.h | 1 + .../lite/nnacl/fp32/arithmetic_self_fp32.c | 4 +-- .../lite/nnacl/int8/quant_dtype_cast_int8.c | 26 ++++++++++++------- mindspore/lite/tools/benchmark/benchmark.cc | 2 ++ .../parser/tflite/tflite_dequantize_parser.cc | 3 +-- .../parser/tflite/tflite_quantize_parser.cc | 3 +-- 6 files changed, 24 insertions(+), 15 deletions(-) diff --git a/mindspore/lite/nnacl/errorcode.h b/mindspore/lite/nnacl/errorcode.h index 09ee45ff68..50d7d76bce 100644 --- a/mindspore/lite/nnacl/errorcode.h +++ b/mindspore/lite/nnacl/errorcode.h @@ -30,6 +30,7 @@ typedef enum ErrorCodeFp32OpEnum { NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC, NNACL_ERRCODE_REVERSE_MALLOC, NNACL_ERRCODE_SQRT_NEGATIVE, + NNACL_ERRCODE_RSQRT_NEGATIVE, NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO, NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO, NNACL_ERRCODE_DIVISOR_ZERO, diff --git a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c index 55a05e568f..bb2fd299c6 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c +++ b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c @@ -68,8 +68,8 @@ int ElementSqrt(const float *input, float *output, const int element_size) { // rsqrt int ElementRsqrt(const float *input, float *output, const int element_size) { for (int i = 0; i < element_size; i++) { - if (input[i] <= 0) { - return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO; + if (input[i] < 0) { + return NNACL_ERRCODE_RSQRT_NEGATIVE; } output[i] = 1.f / sqrtf(input[i]); } diff --git a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c index dfa4e72091..102ab6b091 100644 --- a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c +++ b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c @@ -36,10 +36,14 @@ int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float s const float inverse_scale = 1.0f / scale; for (int i = 0; i < size; ++i) { - int temp = round(real_values[i] * inverse_scale + zp); - temp = temp < 127 ? temp : 127; - temp = temp > -128 ? temp : -128; - quant_values[i] = (int8_t)temp; + if (isinf(real_values[i])) { + quant_values[i] = 127; + } else { + int temp = round(real_values[i] * inverse_scale + zp); + temp = temp < 127 ? temp : 127; + temp = temp > -128 ? temp : -128; + quant_values[i] = (int8_t)temp; + } } return NNACL_OK; } @@ -61,13 +65,17 @@ int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float } for (int i = 0; i < size; ++i) { - float temp = (float)round(real_values[i] * 1.0 / scale + zp); - if (temp > 255) { + if (isinf(real_values[i])) { quant_values[i] = 255; - } else if (temp < 0) { - quant_values[i] = 0; } else { - quant_values[i] = (uint8_t)temp; + float temp = (float)round(real_values[i] * 1.0 / scale + zp); + if (temp > 255) { + quant_values[i] = 255; + } else if (temp < 0) { + quant_values[i] = 0; + } else { + quant_values[i] = (uint8_t)temp; + } } } return NNACL_OK; diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc index 9209ae1767..ac7c936a64 100644 --- a/mindspore/lite/tools/benchmark/benchmark.cc +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -443,6 +443,8 @@ int Benchmark::PrintInputData() { std::cout << static_cast(in_data)[j] << " "; } else if (tensor_data_type == TypeId::kNumberTypeInt32) { std::cout << static_cast(in_data)[j] << " "; + } else if (tensor_data_type == TypeId::kNumberTypeInt64) { + std::cout << static_cast(in_data)[j] << " "; } else { MS_LOG(ERROR) << "Datatype: " << tensor_data_type << " is not supported."; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 44465ddb5c..5f8f277575 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -33,8 +33,7 @@ PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptrtype) != GetTfliteDataType(out_tensor->type) && - (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || + if ((GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) { std::unique_ptr attr = std::make_unique(); if (attr == nullptr) { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index 0c61348cba..4351a6cb16 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -38,8 +38,7 @@ PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptrtype) != GetTfliteDataType(out_tensor->type) && - (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || + if ((GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) { std::unique_ptr attr = std::make_unique(); if (attr == nullptr) {