| @@ -17,6 +17,12 @@ | |||||
| #include "nnacl/fp32/cast.h" | #include "nnacl/fp32/cast.h" | ||||
| #include "nnacl/fp32/common_func.h" | #include "nnacl/fp32/common_func.h" | ||||
| void BoolToFloat32(const bool *input, float *output, int number) { | |||||
| for (int i = 0; i < number; ++i) { | |||||
| output[i] = (float)input[i]; | |||||
| } | |||||
| } | |||||
| void Uint8ToFloat32(const uint8_t *input, float *output, int number) { | void Uint8ToFloat32(const uint8_t *input, float *output, int number) { | ||||
| for (int i = 0; i < number; ++i) { | for (int i = 0; i < number; ++i) { | ||||
| output[i] = (float)input[i]; | output[i] = (float)input[i]; | ||||
| @@ -31,6 +31,7 @@ typedef struct CastParameter { | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| void BoolToFloat32(const bool *input, float *output, int number); | |||||
| void Uint8ToFloat32(const uint8_t *input, float *output, int number); | void Uint8ToFloat32(const uint8_t *input, float *output, int number); | ||||
| void Uint8ToInt8(const uint8_t *input, int8_t *output, int number); | void Uint8ToInt8(const uint8_t *input, int8_t *output, int number); | ||||
| void Int8ToUint8(const int8_t *input, uint8_t *output, int number); | void Int8ToUint8(const int8_t *input, uint8_t *output, int number); | ||||
| @@ -82,6 +82,9 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| } | } | ||||
| } else { | } else { | ||||
| switch (input_data_type) { | switch (input_data_type) { | ||||
| case kNumberTypeBool: | |||||
| BoolToFloat32(reinterpret_cast<bool *>(input->MutableData()) + offset, | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||||
| case kNumberTypeUInt8: | case kNumberTypeUInt8: | ||||
| Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | Uint8ToFloat32(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, | ||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | reinterpret_cast<float *>(output_data) + offset, data_num); | ||||
| @@ -46,9 +46,6 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | attr->srcT = GetTfliteDataType(in_tensor->type); | ||||
| if (attr->srcT == TypeId::kNumberTypeBool) { | |||||
| attr->srcT = TypeId::kNumberTypeUInt8; | |||||
| } | |||||
| const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; | const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; | ||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||