| @@ -90,6 +90,7 @@ union PrimitiveType { | |||||
| Rsqrt, | Rsqrt, | ||||
| ExpandDims, | ExpandDims, | ||||
| Tile, | Tile, | ||||
| Fp16Cast, | |||||
| Cast, | Cast, | ||||
| Shape, | Shape, | ||||
| Nchw2Nhwc, | Nchw2Nhwc, | ||||
| @@ -576,6 +576,11 @@ table Cast { | |||||
| dstT: int; | dstT: int; | ||||
| } | } | ||||
| table Fp16Cast { | |||||
| srcT: int; | |||||
| dstT: int; | |||||
| } | |||||
| table QuantDTypeCast { | table QuantDTypeCast { | ||||
| srcT: int; | srcT: int; | ||||
| dstT: int; | dstT: int; | ||||
| @@ -46,7 +46,7 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||||
| } | } | ||||
| output->SetFormat(input->GetFormat()); | output->SetFormat(input->GetFormat()); | ||||
| output->set_shape(input->shape()); | output->set_shape(input->shape()); | ||||
| output->set_data_type(input->data_type()); | |||||
| output->set_data_type(TypeId::kNumberTypeFloat32); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -37,7 +37,7 @@ constexpr uint32_t kNHWC_w_index = 2; | |||||
| constexpr uint32_t kNHWC_c_index = 3; | constexpr uint32_t kNHWC_c_index = 3; | ||||
| constexpr uint32_t kDimension_4d = 4; | constexpr uint32_t kDimension_4d = 4; | ||||
| const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32}; | |||||
| const std::set<int> kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeFloat16}; | |||||
| class Primitive { | class Primitive { | ||||
| public: | public: | ||||
| @@ -27,6 +27,7 @@ using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::PrimitiveType_Cast; | using mindspore::schema::PrimitiveType_Cast; | ||||
| using mindspore::schema::PrimitiveType_Fp16Cast; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| namespace { | namespace { | ||||
| @@ -87,6 +88,10 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset, | Int32ToFloat32(reinterpret_cast<int32_t *>(input->Data()) + offset, | ||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | reinterpret_cast<float *>(output_data) + offset, data_num); | ||||
| break; | break; | ||||
| case kNumberTypeFloat16: | |||||
| Fp16ToFloat32(reinterpret_cast<int16_t *>(input->Data()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| break; | |||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupport input data type " << input_data_type; | MS_LOG(ERROR) << "Unsupport input data type " << input_data_type; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -139,4 +144,5 @@ kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector<lite::tensor::Ten | |||||
| } | } | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, CpuCastFp32KernelCreator) | REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, CpuCastFp32KernelCreator) | ||||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Fp16Cast, CpuCastFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -265,3 +265,4 @@ void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "nnacl/fp32/cast.h" | #include "nnacl/fp32/cast.h" | ||||
| #include "nnacl/fp32/common_func.h" | |||||
| void Uint8ToFloat32(const uint8_t *input, float *output, int number) { | void Uint8ToFloat32(const uint8_t *input, float *output, int number) { | ||||
| for (int i = 0; i < number; ++i) { | for (int i = 0; i < number; ++i) { | ||||
| @@ -40,6 +41,12 @@ void Int32ToFloat32(const int32_t *input, float *output, int number) { | |||||
| } | } | ||||
| } | } | ||||
| void Fp16ToFloat32(const int16_t *input, float *output, int number) { | |||||
| for (int i = 0; i < number; ++i) { | |||||
| output[i] = ShortToFloat32(input[i]); | |||||
| } | |||||
| } | |||||
| void Float32ToInt32(const float *input, int32_t *output, int number) { | void Float32ToInt32(const float *input, int32_t *output, int number) { | ||||
| for (int i = 0; i < number; ++i) { | for (int i = 0; i < number; ++i) { | ||||
| output[i] = (int32_t)input[i]; | output[i] = (int32_t)input[i]; | ||||
| @@ -35,6 +35,7 @@ void Uint8ToFloat32(const uint8_t *input, float *output, int number); | |||||
| void Uint8ToInt8(const uint8_t *input, int8_t *output, int number); | void Uint8ToInt8(const uint8_t *input, int8_t *output, int number); | ||||
| void Int8ToUint8(const int8_t *input, uint8_t *output, int number); | void Int8ToUint8(const int8_t *input, uint8_t *output, int number); | ||||
| void Int32ToFloat32(const int32_t *input, float *output, int number); | void Int32ToFloat32(const int32_t *input, float *output, int number); | ||||
| void Fp16ToFloat32(const int16_t *input, float *output, int number); | |||||
| void Float32ToInt32(const float *input, int32_t *output, int number); | void Float32ToInt32(const float *input, int32_t *output, int number); | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -113,3 +113,124 @@ void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bi | |||||
| PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM); | PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM); | ||||
| return; | return; | ||||
| } | } | ||||
| static const unsigned int FP32_BIT_SIZE = 32; | |||||
| static const unsigned int FP32_EXPONENT_BIAS = 127; | |||||
| static const unsigned int FP32_SIGNIFICAND = 23; | |||||
| static const unsigned int FP32_EXPONENT_MAX = 255; | |||||
| static const unsigned int FP16_BIT_SIZE = 16; | |||||
| static const unsigned int FP16_EXPONENT_BIAS = 15; | |||||
| static const unsigned int FP16_SIGNIFICAND = 10; | |||||
| static const int FP16_EXPONENT_MAX = 30; | |||||
| static const int FP16_EXPONENT_MIN = -10; | |||||
| float ShortToFloat32(int16_t srcValue) { | |||||
| uint16_t expHalf16 = srcValue & 0x7C00; | |||||
| int exp1 = (int)(expHalf16); | |||||
| uint16_t mantissa16 = srcValue & 0x03FF; | |||||
| int mantissa1 = (int)(mantissa16); | |||||
| int sign = (int)(srcValue & 0x8000); | |||||
| sign = sign << FP16_BIT_SIZE; | |||||
| // nan or inf | |||||
| if (expHalf16 == 0x7C00) { | |||||
| // nan | |||||
| if (mantissa16 > 0) { | |||||
| int res = (0x7FC00000 | sign); | |||||
| int *iRes = &res; | |||||
| auto fres = (float)(*iRes); | |||||
| return fres; | |||||
| } | |||||
| // inf | |||||
| int res = (0x7F800000 | sign); | |||||
| int *iRes = &res; | |||||
| auto fres = (float)(*iRes); | |||||
| return fres; | |||||
| } | |||||
| if (expHalf16 != 0) { | |||||
| exp1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS) << FP16_SIGNIFICAND); // exponents converted to float32 bias | |||||
| int res = (exp1 | mantissa1); | |||||
| res = res << (FP32_SIGNIFICAND - FP16_SIGNIFICAND); | |||||
| res = (res | sign); | |||||
| int *iRes = &res; | |||||
| auto fres = (float)(*iRes); | |||||
| return fres; | |||||
| } | |||||
| int xmm1 = exp1 > (1 << FP16_SIGNIFICAND) ? exp1 : (1 << FP16_SIGNIFICAND); | |||||
| xmm1 = (xmm1 << (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); | |||||
| xmm1 += ((FP32_EXPONENT_BIAS - FP16_EXPONENT_BIAS - FP16_SIGNIFICAND) | |||||
| << FP32_SIGNIFICAND); // add the bias difference to xmm1 | |||||
| xmm1 = xmm1 | sign; // Combine with the sign mask | |||||
| auto res = (float)(mantissa1); // Convert mantissa to float | |||||
| int *ixmm1 = NULL; | |||||
| ixmm1 = &xmm1; | |||||
| res *= (float)(*ixmm1); | |||||
| return res; | |||||
| } | |||||
| // __gnu_f2h_ieee | |||||
| int16_t Float32ToShort(float srcValue) { | |||||
| float *psrcValue = NULL; | |||||
| psrcValue = &srcValue; | |||||
| auto srcValueBit = (unsigned int)(*psrcValue); | |||||
| int sign = srcValueBit >> (FP32_BIT_SIZE - 1); | |||||
| int mantissa = srcValueBit & 0x007FFFFF; | |||||
| // exponent | |||||
| int exp = ((srcValueBit & 0x7F800000) >> FP32_SIGNIFICAND) + FP16_EXPONENT_BIAS - FP32_EXPONENT_BIAS; | |||||
| int16_t res; | |||||
| if (exp > 0 && exp < FP16_EXPONENT_MAX) { | |||||
| // use rte rounding mode, round the significand, combine sign, exponent and significand into a short. | |||||
| res = (sign << (FP16_BIT_SIZE - 1)) | (exp << FP16_SIGNIFICAND) | | |||||
| ((mantissa + 0x00001000) >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); | |||||
| } else if (srcValueBit == 0) { | |||||
| res = 0; | |||||
| } else { | |||||
| if (exp <= 0) { | |||||
| if (exp < FP16_EXPONENT_MIN) { | |||||
| // value is less than min half float point | |||||
| res = 0; | |||||
| } else { | |||||
| // normalized single, magnitude is less than min normal half float point. | |||||
| mantissa = (mantissa | 0x00800000) >> (1 - exp); | |||||
| // round to nearest | |||||
| if ((mantissa & 0x00001000) > 0) { | |||||
| mantissa = mantissa + 0x00002000; | |||||
| } | |||||
| // combine sign & mantissa (exp is zero to get denormalized number) | |||||
| res = (sign << FP16_EXPONENT_BIAS) | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); | |||||
| } | |||||
| } else if (exp == (FP32_EXPONENT_MAX - FP32_EXPONENT_BIAS + FP16_EXPONENT_BIAS)) { | |||||
| if (mantissa == 0) { | |||||
| // input float is infinity, return infinity half | |||||
| res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; | |||||
| } else { | |||||
| // input float is NaN, return half NaN | |||||
| res = (sign << FP16_EXPONENT_BIAS) | 0x7C00 | (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); | |||||
| } | |||||
| } else { | |||||
| // exp > 0, normalized single, round to nearest | |||||
| if ((mantissa & 0x00001000) > 0) { | |||||
| mantissa = mantissa + 0x00002000; | |||||
| if ((mantissa & 0x00800000) > 0) { | |||||
| mantissa = 0; | |||||
| exp = exp + 1; | |||||
| } | |||||
| } | |||||
| if (exp > FP16_EXPONENT_MAX) { | |||||
| // exponent overflow - return infinity half | |||||
| res = (sign << FP16_EXPONENT_BIAS) | 0x7C00; | |||||
| } else { | |||||
| // combine sign, exp and mantissa into normalized half | |||||
| res = (sign << FP16_EXPONENT_BIAS) | (exp << FP16_SIGNIFICAND) | | |||||
| (mantissa >> (FP32_SIGNIFICAND - FP16_SIGNIFICAND)); | |||||
| } | |||||
| } | |||||
| } | |||||
| return res; | |||||
| } | |||||
| @@ -37,6 +37,9 @@ void MatrixSub(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stri | |||||
| size_t row, size_t col); | size_t row, size_t col); | ||||
| void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, | void MatrixMultiAdd(float *c11, float *c12, float *c21, float *c22, float *x_ptr, size_t row, size_t col, | ||||
| size_t c_stride, size_t x_stride); | size_t c_stride, size_t x_stride); | ||||
| int16_t Float32ToShort(float srcValue); | |||||
| float ShortToFloat32(int16_t srcValue); | |||||
| #ifdef ENABLE_ARM | #ifdef ENABLE_ARM | ||||
| void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | ||||
| @@ -14,12 +14,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef LITE_TFLITE_CAST_PARSER_H | |||||
| #ifndef LITE_TFLITE_CAST_PARSER_ | |||||
| #define LITE_TFLITE_CAST_PARSER_H | #define LITE_TFLITE_CAST_PARSER_H | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | #include "tools/converter/parser/tflite/tflite_node_parser.h" | ||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | ||||
| @@ -35,19 +34,6 @@ class TfliteCastParser : public TfliteNodeParser { | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tflite_opset, schema::CNodeT *op, | ||||
| TensorCache *tensor_cache, | TensorCache *tensor_cache, | ||||
| bool quantized_model) override; | bool quantized_model) override; | ||||
| private: | |||||
| std::map<int, TypeId> dtype_map = { | |||||
| {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, | |||||
| {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, | |||||
| {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, | |||||
| {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, | |||||
| {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, | |||||
| {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, | |||||
| {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, | |||||
| {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, | |||||
| {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, | |||||
| }; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tflite/tflite_dequantize_parser.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/common/node_util.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, | |||||
| schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) { | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; | |||||
| std::unique_ptr<schema::CastT> attr(new schema::CastT); | |||||
| // get the dequantize input tensor | |||||
| const auto &in_tensor = tfliteTensors[tfliteOp->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "weight_tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->srcT = dtype_map[in_tensor->type]; | |||||
| const auto &out_tensor = tfliteTensors[tfliteOp->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->dstT = dtype_map[out_tensor->type]; | |||||
| std::vector<tflite::TensorT *> weight_tensors{in_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Fp16Cast; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return 0; | |||||
| } | |||||
| TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_TFLITE_DEQUANTIZE_PARSER_H | |||||
| #define LITE_TFLITE_DEQUANTIZE_PARSER_H | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser.h" | |||||
| #include "tools/converter/parser/tflite/tflite_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TfliteDequantizeParser : public TfliteNodeParser { | |||||
| public: | |||||
| TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} | |||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | |||||
| const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op, | |||||
| TensorCache *tensor_cache, bool quantizedModel) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_TFLITE_DEQUANTIZE_PARSER_H | |||||
| @@ -208,10 +208,12 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_ | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed"; | |||||
| return RET_ERROR; | |||||
| if (quantType != schema::QuantType_QUANT_NONE) { | |||||
| status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | } | ||||
| subGraph->nodes.emplace_back(std::move(op)); | subGraph->nodes.emplace_back(std::move(op)); | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | |||||
| #include <memory> | #include <memory> | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| @@ -126,6 +127,17 @@ class TfliteNodeParser { | |||||
| protected: | protected: | ||||
| const std::string &name; | const std::string &name; | ||||
| std::map<int, TypeId> dtype_map = { | |||||
| {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, | |||||
| {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, | |||||
| {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, | |||||
| {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, | |||||
| {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, | |||||
| {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, | |||||
| {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, | |||||
| {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, | |||||
| {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, | |||||
| }; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||