diff --git a/mindspore/lite/nnacl/fp32/cast.c b/mindspore/lite/nnacl/fp32/cast.c index c98922cc01..2b3d4e7274 100644 --- a/mindspore/lite/nnacl/fp32/cast.c +++ b/mindspore/lite/nnacl/fp32/cast.c @@ -17,6 +17,12 @@ #include "nnacl/fp32/cast.h" #include "nnacl/fp32/common_func.h" +void BoolToFloat32(const bool *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + void Uint8ToFloat32(const uint8_t *input, float *output, int number) { for (int i = 0; i < number; ++i) { output[i] = (float)input[i]; diff --git a/mindspore/lite/nnacl/fp32/cast.h b/mindspore/lite/nnacl/fp32/cast.h index 5f4d12383a..23c5e851bf 100644 --- a/mindspore/lite/nnacl/fp32/cast.h +++ b/mindspore/lite/nnacl/fp32/cast.h @@ -31,6 +31,7 @@ typedef struct CastParameter { #ifdef __cplusplus extern "C" { #endif +void BoolToFloat32(const bool *input, float *output, int number); void Uint8ToFloat32(const uint8_t *input, float *output, int number); void Uint8ToInt8(const uint8_t *input, int8_t *output, int number); void Int8ToUint8(const int8_t *input, uint8_t *output, int number); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index 41bd500fbf..e2f0c8e201 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -82,6 +82,9 @@ int CastCPUKernel::DoCast(int thread_id) { } } else { switch (input_data_type) { + case kNumberTypeBool: + BoolToFloat32(reinterpret_cast(input->MutableData()) + offset, + reinterpret_cast(output_data) + offset, data_num); case kNumberTypeUInt8: Uint8ToFloat32(reinterpret_cast(input->MutableData()) + offset, reinterpret_cast(output_data) + offset, data_num); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc index 383474316c..da045a9bb7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc @@ -46,9 +46,6 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu return RET_NULL_PTR; } attr->srcT = GetTfliteDataType(in_tensor->type); - if (attr->srcT == TypeId::kNumberTypeBool) { - attr->srcT = TypeId::kNumberTypeUInt8; - } const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null";