| @@ -30,6 +30,7 @@ typedef enum ErrorCodeFp32OpEnum { | |||||
| NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC, | NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC, | ||||
| NNACL_ERRCODE_REVERSE_MALLOC, | NNACL_ERRCODE_REVERSE_MALLOC, | ||||
| NNACL_ERRCODE_SQRT_NEGATIVE, | NNACL_ERRCODE_SQRT_NEGATIVE, | ||||
| NNACL_ERRCODE_RSQRT_NEGATIVE, | |||||
| NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO, | NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO, | ||||
| NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO, | NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO, | ||||
| NNACL_ERRCODE_DIVISOR_ZERO, | NNACL_ERRCODE_DIVISOR_ZERO, | ||||
| @@ -68,8 +68,8 @@ int ElementSqrt(const float *input, float *output, const int element_size) { | |||||
| // rsqrt | // rsqrt | ||||
| int ElementRsqrt(const float *input, float *output, const int element_size) { | int ElementRsqrt(const float *input, float *output, const int element_size) { | ||||
| for (int i = 0; i < element_size; i++) { | for (int i = 0; i < element_size; i++) { | ||||
| if (input[i] <= 0) { | |||||
| return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO; | |||||
| if (input[i] < 0) { | |||||
| return NNACL_ERRCODE_RSQRT_NEGATIVE; | |||||
| } | } | ||||
| output[i] = 1.f / sqrtf(input[i]); | output[i] = 1.f / sqrtf(input[i]); | ||||
| } | } | ||||
| @@ -36,10 +36,14 @@ int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float s | |||||
| const float inverse_scale = 1.0f / scale; | const float inverse_scale = 1.0f / scale; | ||||
| for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
| int temp = round(real_values[i] * inverse_scale + zp); | |||||
| temp = temp < 127 ? temp : 127; | |||||
| temp = temp > -128 ? temp : -128; | |||||
| quant_values[i] = (int8_t)temp; | |||||
| if (isinf(real_values[i])) { | |||||
| quant_values[i] = 127; | |||||
| } else { | |||||
| int temp = round(real_values[i] * inverse_scale + zp); | |||||
| temp = temp < 127 ? temp : 127; | |||||
| temp = temp > -128 ? temp : -128; | |||||
| quant_values[i] = (int8_t)temp; | |||||
| } | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -61,13 +65,17 @@ int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float | |||||
| } | } | ||||
| for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
| float temp = (float)round(real_values[i] * 1.0 / scale + zp); | |||||
| if (temp > 255) { | |||||
| if (isinf(real_values[i])) { | |||||
| quant_values[i] = 255; | quant_values[i] = 255; | ||||
| } else if (temp < 0) { | |||||
| quant_values[i] = 0; | |||||
| } else { | } else { | ||||
| quant_values[i] = (uint8_t)temp; | |||||
| float temp = (float)round(real_values[i] * 1.0 / scale + zp); | |||||
| if (temp > 255) { | |||||
| quant_values[i] = 255; | |||||
| } else if (temp < 0) { | |||||
| quant_values[i] = 0; | |||||
| } else { | |||||
| quant_values[i] = (uint8_t)temp; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| @@ -443,6 +443,8 @@ int Benchmark::PrintInputData() { | |||||
| std::cout << static_cast<const uint8_t *>(in_data)[j] << " "; | std::cout << static_cast<const uint8_t *>(in_data)[j] << " "; | ||||
| } else if (tensor_data_type == TypeId::kNumberTypeInt32) { | } else if (tensor_data_type == TypeId::kNumberTypeInt32) { | ||||
| std::cout << static_cast<const int32_t *>(in_data)[j] << " "; | std::cout << static_cast<const int32_t *>(in_data)[j] << " "; | ||||
| } else if (tensor_data_type == TypeId::kNumberTypeInt64) { | |||||
| std::cout << static_cast<const int64_t *>(in_data)[j] << " "; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Datatype: " << tensor_data_type << " is not supported."; | MS_LOG(ERROR) << "Datatype: " << tensor_data_type << " is not supported."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -33,8 +33,7 @@ PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptr<tfl | |||||
| MS_LOG(ERROR) << "output tensor is null"; | MS_LOG(ERROR) << "output tensor is null"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && | |||||
| (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || | |||||
| if ((GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || | |||||
| GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) { | GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) { | ||||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | ||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| @@ -38,8 +38,7 @@ PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptr<tflit | |||||
| MS_LOG(ERROR) << "output tensor is null"; | MS_LOG(ERROR) << "output tensor is null"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) && | |||||
| (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || | |||||
| if ((GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 || | |||||
| GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) { | GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) { | ||||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | ||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||