Browse Source

!9794 [MSLITE]fix the error that occurs when the denominator of Rsprt is 0

From: @probiotics_53
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
3b9340f726
6 changed files with 24 additions and 15 deletions
  1. +1
    -0
      mindspore/lite/nnacl/errorcode.h
  2. +2
    -2
      mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c
  3. +17
    -9
      mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c
  4. +2
    -0
      mindspore/lite/tools/benchmark/benchmark.cc
  5. +1
    -2
      mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc
  6. +1
    -2
      mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc

+ 1
- 0
mindspore/lite/nnacl/errorcode.h View File

@@ -30,6 +30,7 @@ typedef enum ErrorCodeFp32OpEnum {
NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC,
NNACL_ERRCODE_REVERSE_MALLOC,
NNACL_ERRCODE_SQRT_NEGATIVE,
NNACL_ERRCODE_RSQRT_NEGATIVE,
NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO,
NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO,
NNACL_ERRCODE_DIVISOR_ZERO,


+ 2
- 2
mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c View File

@@ -68,8 +68,8 @@ int ElementSqrt(const float *input, float *output, const int element_size) {
// rsqrt
int ElementRsqrt(const float *input, float *output, const int element_size) {
for (int i = 0; i < element_size; i++) {
if (input[i] <= 0) {
return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO;
if (input[i] < 0) {
return NNACL_ERRCODE_RSQRT_NEGATIVE;
}
output[i] = 1.f / sqrtf(input[i]);
}


+ 17
- 9
mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c View File

@@ -36,10 +36,14 @@ int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float s

const float inverse_scale = 1.0f / scale;
for (int i = 0; i < size; ++i) {
int temp = round(real_values[i] * inverse_scale + zp);
temp = temp < 127 ? temp : 127;
temp = temp > -128 ? temp : -128;
quant_values[i] = (int8_t)temp;
if (isinf(real_values[i])) {
quant_values[i] = 127;
} else {
int temp = round(real_values[i] * inverse_scale + zp);
temp = temp < 127 ? temp : 127;
temp = temp > -128 ? temp : -128;
quant_values[i] = (int8_t)temp;
}
}
return NNACL_OK;
}
@@ -61,13 +65,17 @@ int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float
}

for (int i = 0; i < size; ++i) {
float temp = (float)round(real_values[i] * 1.0 / scale + zp);
if (temp > 255) {
if (isinf(real_values[i])) {
quant_values[i] = 255;
} else if (temp < 0) {
quant_values[i] = 0;
} else {
quant_values[i] = (uint8_t)temp;
float temp = (float)round(real_values[i] * 1.0 / scale + zp);
if (temp > 255) {
quant_values[i] = 255;
} else if (temp < 0) {
quant_values[i] = 0;
} else {
quant_values[i] = (uint8_t)temp;
}
}
}
return NNACL_OK;


+ 2
- 0
mindspore/lite/tools/benchmark/benchmark.cc View File

@@ -443,6 +443,8 @@ int Benchmark::PrintInputData() {
std::cout << static_cast<const uint8_t *>(in_data)[j] << " ";
} else if (tensor_data_type == TypeId::kNumberTypeInt32) {
std::cout << static_cast<const int32_t *>(in_data)[j] << " ";
} else if (tensor_data_type == TypeId::kNumberTypeInt64) {
std::cout << static_cast<const int64_t *>(in_data)[j] << " ";
} else {
MS_LOG(ERROR) << "Datatype: " << tensor_data_type << " is not supported.";
return RET_ERROR;


+ 1
- 2
mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc View File

@@ -33,8 +33,7 @@ PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptr<tfl
MS_LOG(ERROR) << "output tensor is null";
return nullptr;
}
if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) &&
(GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 ||
if ((GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 ||
GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {


+ 1
- 2
mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc View File

@@ -38,8 +38,7 @@ PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptr<tflit
MS_LOG(ERROR) << "output tensor is null";
return nullptr;
}
if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) &&
(GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
if ((GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {


Loading…
Cancel
Save