diff --git a/mindspore/lite/nnacl/fp32/cast.c b/mindspore/lite/nnacl/fp32/cast.c index 2b3d4e7274..4e8f13225c 100644 --- a/mindspore/lite/nnacl/fp32/cast.c +++ b/mindspore/lite/nnacl/fp32/cast.c @@ -64,3 +64,15 @@ void Float32ToInt32(const float *input, int32_t *output, int number) { output[i] = (int32_t)input[i]; } } + +void Float32ToInt64(const float *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +void Int32ToInt64(const int32_t *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} diff --git a/mindspore/lite/nnacl/fp32/cast.h b/mindspore/lite/nnacl/fp32/cast.h index 23c5e851bf..4923e2e78d 100644 --- a/mindspore/lite/nnacl/fp32/cast.h +++ b/mindspore/lite/nnacl/fp32/cast.h @@ -39,6 +39,8 @@ void Int32ToFloat32(const int32_t *input, float *output, int number); void Fp16ToFloat32(const uint16_t *input, float *output, int number); void Float32ToFp16(const float *input, uint16_t *output, int number); void Float32ToInt32(const float *input, int32_t *output, int number); +void Float32ToInt64(const float *input, int64_t *output, int number); +void Int32ToInt64(const int32_t *input, int64_t *output, int number); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index 1053d6dd7f..cf07155ea1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -71,13 +71,18 @@ int CastCPUKernel::DoCast(int thread_id) { auto input_data_type = input->data_type(); auto output_data_type = output->data_type(); if (output_data_type != kNumberTypeFloat32) { - if (input_data_type == kNumberTypeFloat32 && - (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { + if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { + Float32ToInt64(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { Float32ToInt32(reinterpret_cast(input->data_c()) + offset, reinterpret_cast(output_data) + offset, data_num); } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) { Float32ToFp16(reinterpret_cast(input->data_c()) + offset, reinterpret_cast(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { + Int32ToInt64(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); } else if (input_data_type == kNumberTypeInt32 && (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { memcpy(reinterpret_cast(output_data) + offset, reinterpret_cast(input->data_c()) + offset, diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg index 692086e3b6..b106a50983 100644 --- a/mindspore/lite/test/models_onnx.cfg +++ b/mindspore/lite/test/models_onnx.cfg @@ -14,3 +14,5 @@ crnn_lite_lstm_v2.onnx;32,32,32,1 psenet_lite_mbv2.onnx;1,32,32,3 super-resolution-10.onnx;1,224,224,1 tinyyolov2-8.onnx;1,416,416,3 +ml_2012_ocr_cn.onnx +ml_2012_ocr_cn_noLSTM.onnx