From 8a29d90d3ca9e0e5c2abf64b767dd5fd9ceee060 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Mon, 17 Aug 2020 09:37:10 +0800 Subject: [PATCH] solve fp16tofp32 bug --- mindspore/lite/schema/model.fbs | 1 - mindspore/lite/schema/ops.fbs | 5 - .../lite/src/runtime/kernel/arm/fp32/cast.cc | 12 +- .../src/runtime/kernel/arm/nnacl/fp32/cast.c | 8 +- .../src/runtime/kernel/arm/nnacl/fp32/cast.h | 3 +- .../kernel/arm/nnacl/fp32/common_func.c | 175 +++++++----------- .../kernel/arm/nnacl/fp32/common_func.h | 5 +- .../parser/tflite/tflite_dequantize_parser.cc | 2 +- .../fusion/constant_folding_fusion.cc | 1 + 9 files changed, 85 insertions(+), 127 deletions(-) diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index caa2b91dd9..0c0ad360b6 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -90,7 +90,6 @@ union PrimitiveType { Rsqrt, ExpandDims, Tile, - Fp16Cast, Cast, Shape, Nchw2Nhwc, diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 0ec683e2ec..2e93983c16 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -581,11 +581,6 @@ table Cast { dstT: int; } -table Fp16Cast { - srcT: int; - dstT: int; -} - table QuantDTypeCast { srcT: int; dstT: int; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index 1eef2fb828..bb8f45591f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -27,7 +27,6 @@ using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Cast; -using mindspore::schema::PrimitiveType_Fp16Cast; namespace mindspore::kernel { namespace { @@ -74,6 +73,9 @@ int CastCPUKernel::DoCast(int thread_id) { if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { Float32ToInt32(reinterpret_cast(input->Data()) + offset, reinterpret_cast(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) { + Float32ToFp16(reinterpret_cast(input->Data()) + offset, + reinterpret_cast(output_data) + offset, data_num); } else { MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; return RET_ERROR; @@ -89,8 +91,8 @@ int CastCPUKernel::DoCast(int thread_id) { reinterpret_cast(output_data) + offset, data_num); break; case kNumberTypeFloat16: - Fp16ToFloat32(reinterpret_cast(input->Data()) + offset, - reinterpret_cast(output_data) + offset, data_num); + Fp16ToFloat32(reinterpret_cast(input->Data()) + offset, + reinterpret_cast(output_data) + offset, data_num); break; default: MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; @@ -144,5 +146,7 @@ kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector 0) { - int res = (0x7FC00000 | sign); - int *iRes = &res; - auto fres = (float)(*iRes); - return fres; - } - // inf - int res = (0x7F800000 | sign); - int *iRes = &res; - auto fres = (float)(*iRes); - return fres; - } - if (expHalf16 != 0) { - exp1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS) << FP16_SIGNIFICAND); // exponents converted to float32 bias - int res = (exp1 | mantissa1); - res = res << (FP32_SIGNIFICAND - FP16_SIGNIFICAND); - res = (res | sign); - int *iRes = &res; - auto fres = (float)(*iRes); - return fres; +union float32_bits { + unsigned int u; + float f; +}; +typedef union float32_bits float32_bits; + +float ShortToFloat32(uint16_t srcValue) { + const float32_bits magic = {113 << 23}; + const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + float32_bits o; + + o.u = (srcValue & 0x7fff) << 13; // exponent/mantissa bits + unsigned int exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize } - int xmm1 = exp1 > (1 << FP16_SIGNIFICAND) ? exp1 : (1 << FP16_SIGNIFICAND); - xmm1 = (xmm1 << (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - xmm1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS - FP16_SIGNIFICAND) - << FP32_SIGNIFICAND); // add the bias difference to xmm1 - xmm1 = xmm1 | sign; // Combine with the sign mask - - auto res = (float)(mantissa1); // Convert mantissa to float - int *ixmm1 = NULL; - ixmm1 = &xmm1; - res *= (float)(*ixmm1); - - return res; + o.u |= (srcValue & 0x8000) << 16; // sign bit + return o.f; } -// __gnu_f2h_ieee -int16_t Float32ToShort(float srcValue) { - float *psrcValue = NULL; - psrcValue = &srcValue; - auto srcValueBit = (unsigned int)(*psrcValue); - int sign = srcValueBit >> (FP32_BIT_SIZE - 1); - int mantissa = srcValueBit & 0x007FFFFF; - // exponent - int exp = ((srcValueBit & 0x7F800000) >> FP32_SIGNIFICAND) + FP16_EXPONENT_BIAS - FP32_EXPONENT_BIAS; - int16_t res; - if (exp > 0 && exp < FP16_EXPONENT_MAX) { - // use rte rounding mode, round the significand, combine sign, exponent and significand into a short. - res = (sign << (FP16_BIT_SIZE - 1)) | (exp << FP16_SIGNIFICAND) | - ((mantissa + 0x00001000) >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - } else if (srcValueBit == 0) { - res = 0; - } else { - if (exp <= 0) { - if (exp < FP16_EXPONENT_MIN) { - // value is less than min half float point - res = 0; - } else { - // normalized single, magnitude is less than min normal half float point. - mantissa = (mantissa | 0x00800000) >> (1 - exp); - // round to nearest - if ((mantissa & 0x00001000) > 0) { - mantissa = mantissa + 0x00002000; - } - // combine sign & mantissa (exp is zero to get denormalized number) - res = (sign << FP16_EXPONENT_BIAS) | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - } - } else if (exp == (FP32_EXPONENT_MAX - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS)) { - if (mantissa == 0) { - // input float is infinity, return infinity half - res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; - } else { - // input float is NaN, return half NaN - res = (sign << FP16_EXPONENT_BIAS) | 0x7C00 | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - } +uint16_t Float32ToShort(float srcValue) { + float32_bits f; + f.f = srcValue; + + const float32_bits f32infty = {255 << 23}; + const float32_bits f16max = {(127 + 16) << 23}; + const float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + unsigned int sign_mask = 0x80000000u; + uint16_t o; + + unsigned int sign = f.u & sign_mask; + f.u ^= sign; + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) + o = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + o = (uint16_t)(f.u - denorm_magic.u); } else { - // exp > 0, normalized single, round to nearest - if ((mantissa & 0x00001000) > 0) { - mantissa = mantissa + 0x00002000; - if ((mantissa & 0x00800000) > 0) { - mantissa = 0; - exp = exp + 1; - } - } - if (exp > FP16_EXPONENT_MAX) { - // exponent overflow - return infinity half - res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; - } else { - // combine sign, exp and mantissa into normalized half - res = (sign << FP16_EXPONENT_BIAS) | (exp << FP16_SIGNIFICAND) | - (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); - } + unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + f.u += ((unsigned int)(15 - 127) << 23) + 0xfff; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! + o = (uint16_t)(f.u >> 13); } } - return res; + + o |= (uint16_t)(sign >> 16); + return o; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h index 4b7dfe4120..9f590aefff 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h @@ -37,9 +37,10 @@ void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stri size_t row, size_t col); void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, size_t c_stride, size_t x_stride); -int16_t Float32ToShort(float srcValue); +float ShortToFloat32(uint16_t srcValue); + +uint16_t Float32ToShort(float srcValue); -float ShortToFloat32(int16_t srcValue); #ifdef ENABLE_ARM void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 74bde414e5..bdfca915c2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -58,7 +58,7 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &t return RET_ERROR; } - op->primitive->value.type = schema::PrimitiveType_Fp16Cast; + op->primitive->value.type = schema::PrimitiveType_Cast; op->primitive->value.value = attr.release(); return 0; } diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 8197881445..b1d2e03b89 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -161,6 +161,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An MS_LOG(EXCEPTION) << "run kernel failed, name: " << lite_kernel->name(); } auto new_parameter = CreateNewParamter(func_graph, output_tensors.front()); + new_parameter->set_name(input_node->fullname_with_scope()); any_node->set_input(i, new_parameter); } }