From f3b794bc267e5a68b44ecfc3f7e3923aff199fed Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Fri, 14 Aug 2020 11:17:47 +0800 Subject: [PATCH] support fp16 model --- mindspore/lite/schema/model.fbs | 1 + mindspore/lite/schema/ops.fbs | 5 + mindspore/lite/src/ops/cast.cc | 2 +- mindspore/lite/src/ops/ops.h | 2 +- .../lite/src/runtime/kernel/arm/fp32/cast.cc | 6 + .../runtime/kernel/arm/nnacl/common_func.c | 1 + .../src/runtime/kernel/arm/nnacl/fp32/cast.c | 7 + .../src/runtime/kernel/arm/nnacl/fp32/cast.h | 1 + .../kernel/arm/nnacl/fp32/common_func.c | 121 ++++++++++++++++++ .../kernel/arm/nnacl/fp32/common_func.h | 3 + .../parser/tflite/tflite_cast_parser.h | 16 +-- .../parser/tflite/tflite_dequantize_parser.cc | 68 ++++++++++ .../parser/tflite/tflite_dequantize_parser.h | 39 ++++++ .../parser/tflite/tflite_model_parser.cc | 10 +- .../parser/tflite/tflite_node_parser.h | 12 ++ 15 files changed, 273 insertions(+), 21 deletions(-) create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 0e09da8614..f9f61af25e 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -90,6 +90,7 @@ union PrimitiveType { Rsqrt, ExpandDims, Tile, + Fp16Cast, Cast, Shape, Nchw2Nhwc, diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index cbb07b0be8..b64eb61d80 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -576,6 +576,11 @@ table Cast { dstT: int; } +table Fp16Cast { + srcT: int; + dstT: int; +} + table QuantDTypeCast { srcT: int; dstT: int; diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 565f8de767..7ac3601727 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -46,7 +46,7 @@ int Cast::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); output->set_shape(input->shape()); - output->set_data_type(input->data_type()); + output->set_data_type(TypeId::kNumberTypeFloat32); return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h index 3f0de59261..64de6cdfe7 100644 --- a/mindspore/lite/src/ops/ops.h +++ b/mindspore/lite/src/ops/ops.h @@ -37,7 +37,7 @@ constexpr uint32_t kNHWC_w_index = 2; constexpr uint32_t kNHWC_c_index = 3; constexpr uint32_t kDimension_4d = 4; -const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32}; +const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeFloat16}; class Primitive { public: diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index 3363ef6d51..e74cce6b29 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -27,6 +27,7 @@ using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Cast; +using mindspore::schema::PrimitiveType_Fp16Cast; namespace mindspore::kernel { namespace { @@ -87,6 +88,10 @@ int CastCPUKernel::DoCast(int thread_id) { Int32ToFloat32(reinterpret_cast(input->Data()) + offset, reinterpret_cast(output_data) + offset, data_num); break; + case kNumberTypeFloat16: + Fp16ToFloat32(reinterpret_cast(input->Data()) + offset, + reinterpret_cast(output_data) + offset, data_num); + break; default: MS_LOG(ERROR) << "Unsupport input data type " << input_data_type; return RET_ERROR; @@ -139,4 +144,5 @@ kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector 0) { + int res = (0x7FC00000 | sign); + int *iRes = &res; + auto fres = (float)(*iRes); + return fres; + } + // inf + int res = (0x7F800000 | sign); + int *iRes = &res; + auto fres = (float)(*iRes); + return fres; + } + if (expHalf16 != 0) { + exp1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS) << FP16_SIGNIFICAND); // exponents converted to float32 bias + int res = (exp1 | mantissa1); + res = res << (FP32_SIGNIFICAND - FP16_SIGNIFICAND); + res = (res | sign); + int *iRes = &res; + auto fres = (float)(*iRes); + return fres; + } + + int xmm1 = exp1 > (1 << FP16_SIGNIFICAND) ? exp1 : (1 << FP16_SIGNIFICAND); + xmm1 = (xmm1 << (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + xmm1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS - FP16_SIGNIFICAND) + << FP32_SIGNIFICAND); // add the bias difference to xmm1 + xmm1 = xmm1 | sign; // Combine with the sign mask + + auto res = (float)(mantissa1); // Convert mantissa to float + int *ixmm1 = NULL; + ixmm1 = &xmm1; + res *= (float)(*ixmm1); + + return res; +} + +// __gnu_f2h_ieee +int16_t Float32ToShort(float srcValue) { + float *psrcValue = NULL; + psrcValue = &srcValue; + auto srcValueBit = (unsigned int)(*psrcValue); + int sign = srcValueBit >> (FP32_BIT_SIZE - 1); + int mantissa = srcValueBit & 0x007FFFFF; + // exponent + int exp = ((srcValueBit & 0x7F800000) >> FP32_SIGNIFICAND) + FP16_EXPONENT_BIAS - FP32_EXPONENT_BIAS; + int16_t res; + if (exp > 0 && exp < FP16_EXPONENT_MAX) { + // use rte rounding mode, round the significand, combine sign, exponent and significand into a short. + res = (sign << (FP16_BIT_SIZE - 1)) | (exp << FP16_SIGNIFICAND) | + ((mantissa + 0x00001000) >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } else if (srcValueBit == 0) { + res = 0; + } else { + if (exp <= 0) { + if (exp < FP16_EXPONENT_MIN) { + // value is less than min half float point + res = 0; + } else { + // normalized single, magnitude is less than min normal half float point. + mantissa = (mantissa | 0x00800000) >> (1 - exp); + // round to nearest + if ((mantissa & 0x00001000) > 0) { + mantissa = mantissa + 0x00002000; + } + // combine sign & mantissa (exp is zero to get denormalized number) + res = (sign << FP16_EXPONENT_BIAS) | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } else if (exp == (FP32_EXPONENT_MAX - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS)) { + if (mantissa == 0) { + // input float is infinity, return infinity half + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; + } else { + // input float is NaN, return half NaN + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00 | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } else { + // exp > 0, normalized single, round to nearest + if ((mantissa & 0x00001000) > 0) { + mantissa = mantissa + 0x00002000; + if ((mantissa & 0x00800000) > 0) { + mantissa = 0; + exp = exp + 1; + } + } + if (exp > FP16_EXPONENT_MAX) { + // exponent overflow - return infinity half + res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; + } else { + // combine sign, exp and mantissa into normalized half + res = (sign << FP16_EXPONENT_BIAS) | (exp << FP16_SIGNIFICAND) | + (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); + } + } + } + return res; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h index 110ffbdbe3..a25b4d21fd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/common_func.h @@ -37,6 +37,9 @@ void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stri size_t row, size_t col); void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, size_t c_stride, size_t x_stride); +int16_t Float32ToShort(float srcValue); + +float ShortToFloat32(int16_t srcValue); #ifdef ENABLE_ARM void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h index 843a144929..ae1dca284c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h @@ -14,12 +14,11 @@ * limitations under the License. */ -#ifndef LITE_TFLITE_CAST_PARSER_H +#ifndef LITE_TFLITE_CAST_PARSER_ #define LITE_TFLITE_CAST_PARSER_H #include #include -#include #include "tools/converter/parser/tflite/tflite_node_parser.h" #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" @@ -35,19 +34,6 @@ class TfliteCastParser : public TfliteNodeParser { const std::vector> &tflite_opset, schema::CNodeT *op, TensorCache *tensor_cache, bool quantized_model) override; - - private: - std::map dtype_map = { - {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, - {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, - {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, - {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, - {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, - {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, - {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, - {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, - {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, - }; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc new file mode 100644 index 0000000000..74bde414e5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tflite/tflite_dequantize_parser.h" +#include +#include +#include "tools/common/node_util.h" + +namespace mindspore { +namespace lite { +STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, + schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; + std::unique_ptr attr(new schema::CastT); + + // get the dequantize input tensor + const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; + if (in_tensor == nullptr) { + MS_LOG(ERROR) << "weight_tensor is null"; + return RET_NULL_PTR; + } + attr->srcT = dtype_map[in_tensor->type]; + + const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]]; + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "tensor is null"; + return RET_NULL_PTR; + } + attr->dstT = dtype_map[out_tensor->type]; + std::vector weight_tensors{in_tensor.get()}; + if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { + MS_LOG(ERROR) << "parse weight failed"; + return RET_ERROR; + } + + op->primitive->value.type = schema::PrimitiveType_Fp16Cast; + op->primitive->value.value = attr.release(); + return 0; +} + +TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h new file mode 100644 index 0000000000..3d6e521d7d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H +#define LITE_TFLITE_DEQUANTIZE_PARSER_H + +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteDequantizeParser : public TfliteNodeParser { + public: + TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} + + STATUS Parse(const std::unique_ptr &tfliteOp, + const std::vector> &tfliteTensors, + const std::vector> &tfliteModelBuffer, + const std::vector> &tfliteOpSet, schema::CNodeT *op, + TensorCache *tensor_cache, bool quantizedModel) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_TFLITE_DEQUANTIZE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index d240a9f295..7ea6fecfd3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -208,10 +208,12 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr &tflite_ return RET_ERROR; } - status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache); - if (status != RET_OK) { - MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed"; - return RET_ERROR; + if (quantType != schema::QuantType_QUANT_NONE) { + status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed"; + return RET_ERROR; + } } subGraph->nodes.emplace_back(std::move(op)); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index 3a3828e0cd..de5f0d1b28 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "utils/log_adapter.h" #include "schema/inner/model_generated.h" @@ -126,6 +127,17 @@ class TfliteNodeParser { protected: const std::string &name; + std::map dtype_map = { + {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, + {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, + {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, + {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, + {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, + {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, + {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, + {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, + {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, + }; }; } // namespace lite } // namespace mindspore