diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index 681d2e551e..1053d6dd7f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -71,7 +71,8 @@ int CastCPUKernel::DoCast(int thread_id) { auto input_data_type = input->data_type(); auto output_data_type = output->data_type(); if (output_data_type != kNumberTypeFloat32) { - if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { + if (input_data_type == kNumberTypeFloat32 && + (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { Float32ToInt32(reinterpret_cast(input->data_c()) + offset, reinterpret_cast(output_data) + offset, data_num); } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) { @@ -81,10 +82,6 @@ int CastCPUKernel::DoCast(int thread_id) { (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { memcpy(reinterpret_cast(output_data) + offset, reinterpret_cast(input->data_c()) + offset, data_num * sizeof(int32_t)); - } else if (input_data_type == kNumberTypeFloat32 && - (output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) { - memcpy(reinterpret_cast(output_data) + offset, reinterpret_cast(input->data_c()) + offset, - data_num * sizeof(float)); } else { MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; return RET_ERROR; diff --git a/mindspore/lite/test/models_tflite_awaretraining.cfg b/mindspore/lite/test/models_tflite_awaretraining.cfg index b4645eeff7..8fa71106b9 100644 --- a/mindspore/lite/test/models_tflite_awaretraining.cfg +++ b/mindspore/lite/test/models_tflite_awaretraining.cfg @@ -36,3 +36,4 @@ vision_classifier_fungi_mobile_V1_1_default_1.tflite detect.tflite ssd_mobilenet_v1_1_default_1.tflite object_detection_mobile_object_localizer_v1_1_default_1.tflite +gts_detect_0730_quant_frozen.tflite